Enha (onnx): unload model if noop for 30s

This commit is contained in:
Grail Finder
2026-03-06 09:32:45 +03:00
parent d2caebdb4f
commit 4ef0a21511
2 changed files with 56 additions and 0 deletions

View File

@@ -308,6 +308,19 @@ func (e *ONNXEmbedder) getModelPath() string {
return e.modelPath return e.modelPath
} }
// Destroy tears down the embedder's ONNX session, if one is currently held.
//
// The call is serialized with other users of e.mu. When no session is held it
// is a no-op and returns nil, so it may be called repeatedly. If the session's
// own Destroy reports an error, that error is returned wrapped and the session
// reference is kept, so a later call can retry.
func (e *ONNXEmbedder) Destroy() error {
	e.mu.Lock()
	defer e.mu.Unlock()

	// Nothing loaded: nothing to release.
	if e.session == nil {
		return nil
	}

	destroyErr := e.session.Destroy()
	if destroyErr != nil {
		return fmt.Errorf("failed to destroy ONNX session: %w", destroyErr)
	}

	// Drop the reference only after a successful teardown.
	e.session = nil
	e.logger.Info("ONNX session destroyed, VRAM freed")
	return nil
}
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
if err := e.ensureInitialized(); err != nil { if err := e.ensureInitialized(); err != nil {
return nil, err return nil, err

View File

@@ -12,6 +12,7 @@ import (
"sort" "sort"
"strings" "strings"
"sync" "sync"
"time"
"github.com/neurosnap/sentences/english" "github.com/neurosnap/sentences/english"
) )
@@ -32,6 +33,8 @@ type RAG struct {
storage *VectorStorage storage *VectorStorage
mu sync.Mutex mu sync.Mutex
fallbackMsg string fallbackMsg string
idleTimer *time.Timer
idleTimeout time.Duration
} }
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) { func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
@@ -58,6 +61,7 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
embedder: embedder, embedder: embedder,
storage: NewVectorStorage(l, s), storage: NewVectorStorage(l, s),
fallbackMsg: fallbackMsg, fallbackMsg: fallbackMsg,
idleTimeout: 30 * time.Second,
} }
// Note: Vector tables are created via database migrations, not at runtime // Note: Vector tables are created via database migrations, not at runtime
@@ -187,6 +191,7 @@ func (r *RAG) LoadRAG(fpath string) error {
} }
} }
r.logger.Debug("finished writing vectors", "batches", batchCount) r.logger.Debug("finished writing vectors", "batches", batchCount)
r.resetIdleTimer()
select { select {
case LongJobStatusCh <- FinishedRAGStatus: case LongJobStatusCh <- FinishedRAGStatus:
default: default:
@@ -196,10 +201,12 @@ func (r *RAG) LoadRAG(fpath string) error {
} }
// LineToVector restarts the idle timer and then returns whatever the
// configured embedder's Embed produces for line (vector and error are passed
// through unchanged).
func (r *RAG) LineToVector(line string) ([]float32, error) { func (r *RAG) LineToVector(line string) ([]float32, error) {
r.resetIdleTimer()
return r.embedder.Embed(line) return r.embedder.Embed(line)
} }
// SearchEmb restarts the idle timer and then returns the result of the
// storage's SearchClosest for emb.Embedding.
// emb must be non-nil: its Embedding field is read unconditionally.
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
r.resetIdleTimer()
return r.storage.SearchClosest(emb.Embedding) return r.storage.SearchClosest(emb.Embedding)
} }
@@ -208,6 +215,7 @@ func (r *RAG) ListLoaded() ([]string, error) {
} }
// RemoveFile restarts the idle timer and then delegates to the storage's
// RemoveEmbByFileName for filename, returning its error unchanged.
func (r *RAG) RemoveFile(filename string) error { func (r *RAG) RemoveFile(filename string) error {
r.resetIdleTimer()
return r.storage.RemoveEmbByFileName(filename) return r.storage.RemoveEmbByFileName(filename)
} }
@@ -471,3 +479,38 @@ func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
// GetInstance returns the package-level ragInstance.
// NOTE(review): presumably nil until Init has run — confirm against Init.
func GetInstance() *RAG { func GetInstance() *RAG {
return ragInstance return ragInstance
} }
// resetIdleTimer (re)arms the countdown after which the ONNX model is
// unloaded.
//
// Each call pushes the deadline idleTimeout into the future. If the countdown
// expires without another call, freeONNXMemory runs on the timer's goroutine.
//
// A non-positive idleTimeout disables idle unloading: any pending countdown is
// cancelled and none is scheduled. Without this guard time.AfterFunc would
// fire immediately and the model would be unloaded on every call.
//
// NOTE(review): idleTimer is read and written here without holding r.mu, while
// Destroy mutates it under r.mu. If this method can be reached from several
// goroutines that is a data race. Confirm how callers synchronize before
// adding a lock here; a caller that already holds r.mu would deadlock.
func (r *RAG) resetIdleTimer() {
	if r.idleTimeout <= 0 {
		if r.idleTimer != nil {
			r.idleTimer.Stop()
			r.idleTimer = nil
		}
		return
	}
	if r.idleTimer != nil {
		// Reuse the existing timer instead of allocating one per call. For an
		// AfterFunc timer, Reset schedules the callback again whether or not
		// it has already fired.
		r.idleTimer.Reset(r.idleTimeout)
		return
	}
	r.idleTimer = time.AfterFunc(r.idleTimeout, r.freeONNXMemory)
}
// freeONNXMemory asks the embedder to release its ONNX session. It is the
// callback run by the idle timer.
//
// The call holds r.mu for its whole duration. It does nothing when the
// configured embedder is not an *ONNXEmbedder. The outcome is only logged;
// nothing is returned to the caller.
func (r *RAG) freeONNXMemory() {
	r.mu.Lock()
	defer r.mu.Unlock()

	onnxEmbedder, isONNX := r.embedder.(*ONNXEmbedder)
	if !isONNX {
		return
	}

	if destroyErr := onnxEmbedder.Destroy(); destroyErr != nil {
		r.logger.Error("failed to free ONNX memory", "error", destroyErr)
		return
	}
	r.logger.Info("freed ONNX VRAM after idle timeout")
}
// Destroy shuts down the RAG's idle-unload machinery and releases the ONNX
// embedder.
//
// Under r.mu it stops and clears any pending idle timer. If the configured
// embedder is an *ONNXEmbedder, it then calls that embedder's Destroy. A
// failure there is logged, not returned.
func (r *RAG) Destroy() {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Cancel the pending countdown so it cannot fire after shutdown.
	if pending := r.idleTimer; pending != nil {
		pending.Stop()
		r.idleTimer = nil
	}

	onnxEmbedder, isONNX := r.embedder.(*ONNXEmbedder)
	if !isONNX {
		return
	}
	if destroyErr := onnxEmbedder.Destroy(); destroyErr != nil {
		r.logger.Error("failed to destroy ONNX embedder", "error", destroyErr)
	}
}