diff --git a/rag/embedder.go b/rag/embedder.go
index 13f6a6e..39f4b5c 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -308,6 +308,26 @@ func (e *ONNXEmbedder) getModelPath() string {
 	return e.modelPath
 }
 
+// Destroy releases the ONNX session, if one is loaded, and clears e.session.
+// It holds e.mu for the duration and returns nil without doing anything when
+// no session exists. If the session fails to shut down, e.session is left in
+// place and the wrapped error is returned.
+//
+// NOTE(review): confirm that ensureInitialized rebuilds the session once
+// e.session is nil; otherwise Embed cannot recover after an idle release.
+func (e *ONNXEmbedder) Destroy() error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	if e.session != nil {
+		if err := e.session.Destroy(); err != nil {
+			return fmt.Errorf("failed to destroy ONNX session: %w", err)
+		}
+		e.session = nil
+		e.logger.Info("ONNX session destroyed, VRAM freed")
+	}
+	return nil
+}
+
 func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
 	if err := e.ensureInitialized(); err != nil {
 		return nil, err
diff --git a/rag/rag.go b/rag/rag.go
index fa30303..d64a3e1 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -12,6 +12,7 @@ import (
 	"sort"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/neurosnap/sentences/english"
 )
@@ -32,6 +33,9 @@ type RAG struct {
 	storage     *VectorStorage
 	mu          sync.Mutex
 	fallbackMsg string
+	idleMu      sync.Mutex    // guards idleTimer
+	idleTimer   *time.Timer   // most recently armed idle timer; nil before first use and after Destroy
+	idleTimeout time.Duration // idle period after which the ONNX session is released
 }
 
 func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
@@ -58,6 +62,7 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
 		embedder:    embedder,
 		storage:     NewVectorStorage(l, s),
 		fallbackMsg: fallbackMsg,
+		idleTimeout: 30 * time.Second,
 	}
 
 	// Note: Vector tables are created via database migrations, not at runtime
@@ -187,6 +192,7 @@ func (r *RAG) LoadRAG(fpath string) error {
 		}
 	}
 	r.logger.Debug("finished writing vectors", "batches", batchCount)
+	r.resetIdleTimer()
 	select {
 	case LongJobStatusCh <- FinishedRAGStatus:
 	default:
@@ -196,10 +202,12 @@ func (r *RAG) LoadRAG(fpath string) error {
 }
 
 func (r *RAG) LineToVector(line string) ([]float32, error) {
+	r.resetIdleTimer()
 	return r.embedder.Embed(line)
 }
 
 func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
+	r.resetIdleTimer()
 	return r.storage.SearchClosest(emb.Embedding)
 }
 
@@ -208,6 +216,7 @@ func (r *RAG) ListLoaded() ([]string, error) {
 }
 
 func (r *RAG) RemoveFile(filename string) error {
+	r.resetIdleTimer()
 	return r.storage.RemoveEmbByFileName(filename)
 }
 
@@ -471,3 +480,57 @@ func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
 func GetInstance() *RAG {
 	return ragInstance
 }
+
+// resetIdleTimer re-arms the idle timer so that freeONNXMemory runs
+// r.idleTimeout after the most recent call.
+//
+// The exported RAG methods that call it may run concurrently, so r.idleTimer is
+// guarded by r.idleMu. A dedicated mutex is used instead of r.mu because
+// sync.Mutex is not reentrant and callers might already hold r.mu.
+func (r *RAG) resetIdleTimer() {
+	r.idleMu.Lock()
+	defer r.idleMu.Unlock()
+	if r.idleTimer != nil {
+		r.idleTimer.Stop()
+	}
+	r.idleTimer = time.AfterFunc(r.idleTimeout, r.freeONNXMemory)
+}
+
+// freeONNXMemory is the idle-timer callback. When the embedder is an
+// *ONNXEmbedder it destroys the ONNX session under r.mu and logs the outcome;
+// for any other embedder it does nothing.
+func (r *RAG) freeONNXMemory() {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	onnx, ok := r.embedder.(*ONNXEmbedder)
+	if !ok {
+		return
+	}
+	if err := onnx.Destroy(); err != nil {
+		r.logger.Error("failed to free ONNX memory", "error", err)
+		return
+	}
+	r.logger.Info("freed ONNX VRAM after idle timeout")
+}
+
+// Destroy cancels any pending idle timer and, when the embedder is an
+// *ONNXEmbedder, destroys its session immediately. Errors from the embedder
+// are logged rather than returned.
+func (r *RAG) Destroy() {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	// Lock order is r.mu then r.idleMu; resetIdleTimer only ever takes r.idleMu.
+	r.idleMu.Lock()
+	if r.idleTimer != nil {
+		r.idleTimer.Stop()
+		r.idleTimer = nil
+	}
+	r.idleMu.Unlock()
+
+	if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
+		if err := onnx.Destroy(); err != nil {
+			r.logger.Error("failed to destroy ONNX embedder", "error", err)
+		}
+	}
+}