Enha (onnx): unload model if noop for 30s
This commit is contained in:
@@ -308,6 +308,19 @@ func (e *ONNXEmbedder) getModelPath() string {
|
||||
return e.modelPath
|
||||
}
|
||||
|
||||
func (e *ONNXEmbedder) Destroy() error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
if e.session != nil {
|
||||
if err := e.session.Destroy(); err != nil {
|
||||
return fmt.Errorf("failed to destroy ONNX session: %w", err)
|
||||
}
|
||||
e.session = nil
|
||||
e.logger.Info("ONNX session destroyed, VRAM freed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||
if err := e.ensureInitialized(); err != nil {
|
||||
return nil, err
|
||||
|
||||
43
rag/rag.go
43
rag/rag.go
@@ -12,6 +12,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/neurosnap/sentences/english"
|
||||
)
|
||||
@@ -32,6 +33,8 @@ type RAG struct {
|
||||
storage *VectorStorage
|
||||
mu sync.Mutex
|
||||
fallbackMsg string
|
||||
idleTimer *time.Timer
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
||||
@@ -58,6 +61,7 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
||||
embedder: embedder,
|
||||
storage: NewVectorStorage(l, s),
|
||||
fallbackMsg: fallbackMsg,
|
||||
idleTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Note: Vector tables are created via database migrations, not at runtime
|
||||
@@ -187,6 +191,7 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
}
|
||||
}
|
||||
r.logger.Debug("finished writing vectors", "batches", batchCount)
|
||||
r.resetIdleTimer()
|
||||
select {
|
||||
case LongJobStatusCh <- FinishedRAGStatus:
|
||||
default:
|
||||
@@ -196,10 +201,12 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
}
|
||||
|
||||
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||
r.resetIdleTimer()
|
||||
return r.embedder.Embed(line)
|
||||
}
|
||||
|
||||
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
|
||||
r.resetIdleTimer()
|
||||
return r.storage.SearchClosest(emb.Embedding)
|
||||
}
|
||||
|
||||
@@ -208,6 +215,7 @@ func (r *RAG) ListLoaded() ([]string, error) {
|
||||
}
|
||||
|
||||
func (r *RAG) RemoveFile(filename string) error {
|
||||
r.resetIdleTimer()
|
||||
return r.storage.RemoveEmbByFileName(filename)
|
||||
}
|
||||
|
||||
@@ -471,3 +479,38 @@ func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
|
||||
func GetInstance() *RAG {
|
||||
return ragInstance
|
||||
}
|
||||
|
||||
func (r *RAG) resetIdleTimer() {
|
||||
if r.idleTimer != nil {
|
||||
r.idleTimer.Stop()
|
||||
}
|
||||
r.idleTimer = time.AfterFunc(r.idleTimeout, func() {
|
||||
r.freeONNXMemory()
|
||||
})
|
||||
}
|
||||
|
||||
func (r *RAG) freeONNXMemory() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
|
||||
if err := onnx.Destroy(); err != nil {
|
||||
r.logger.Error("failed to free ONNX memory", "error", err)
|
||||
} else {
|
||||
r.logger.Info("freed ONNX VRAM after idle timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RAG) Destroy() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.idleTimer != nil {
|
||||
r.idleTimer.Stop()
|
||||
r.idleTimer = nil
|
||||
}
|
||||
if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
|
||||
if err := onnx.Destroy(); err != nil {
|
||||
r.logger.Error("failed to destroy ONNX embedder", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user