Enha (onnx): unload model if noop for 30s
This commit is contained in:
@@ -308,6 +308,19 @@ func (e *ONNXEmbedder) getModelPath() string {
|
|||||||
return e.modelPath
|
return e.modelPath
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *ONNXEmbedder) Destroy() error {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
if e.session != nil {
|
||||||
|
if err := e.session.Destroy(); err != nil {
|
||||||
|
return fmt.Errorf("failed to destroy ONNX session: %w", err)
|
||||||
|
}
|
||||||
|
e.session = nil
|
||||||
|
e.logger.Info("ONNX session destroyed, VRAM freed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||||
if err := e.ensureInitialized(); err != nil {
|
if err := e.ensureInitialized(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
43
rag/rag.go
43
rag/rag.go
@@ -12,6 +12,7 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/neurosnap/sentences/english"
|
"github.com/neurosnap/sentences/english"
|
||||||
)
|
)
|
||||||
@@ -32,6 +33,8 @@ type RAG struct {
|
|||||||
storage *VectorStorage
|
storage *VectorStorage
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
fallbackMsg string
|
fallbackMsg string
|
||||||
|
idleTimer *time.Timer
|
||||||
|
idleTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
||||||
@@ -58,6 +61,7 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
|||||||
embedder: embedder,
|
embedder: embedder,
|
||||||
storage: NewVectorStorage(l, s),
|
storage: NewVectorStorage(l, s),
|
||||||
fallbackMsg: fallbackMsg,
|
fallbackMsg: fallbackMsg,
|
||||||
|
idleTimeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: Vector tables are created via database migrations, not at runtime
|
// Note: Vector tables are created via database migrations, not at runtime
|
||||||
@@ -187,6 +191,7 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.logger.Debug("finished writing vectors", "batches", batchCount)
|
r.logger.Debug("finished writing vectors", "batches", batchCount)
|
||||||
|
r.resetIdleTimer()
|
||||||
select {
|
select {
|
||||||
case LongJobStatusCh <- FinishedRAGStatus:
|
case LongJobStatusCh <- FinishedRAGStatus:
|
||||||
default:
|
default:
|
||||||
@@ -196,10 +201,12 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||||
|
r.resetIdleTimer()
|
||||||
return r.embedder.Embed(line)
|
return r.embedder.Embed(line)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
|
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
|
||||||
|
r.resetIdleTimer()
|
||||||
return r.storage.SearchClosest(emb.Embedding)
|
return r.storage.SearchClosest(emb.Embedding)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,6 +215,7 @@ func (r *RAG) ListLoaded() ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) RemoveFile(filename string) error {
|
func (r *RAG) RemoveFile(filename string) error {
|
||||||
|
r.resetIdleTimer()
|
||||||
return r.storage.RemoveEmbByFileName(filename)
|
return r.storage.RemoveEmbByFileName(filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -471,3 +479,38 @@ func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
|
|||||||
func GetInstance() *RAG {
|
func GetInstance() *RAG {
|
||||||
return ragInstance
|
return ragInstance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *RAG) resetIdleTimer() {
|
||||||
|
if r.idleTimer != nil {
|
||||||
|
r.idleTimer.Stop()
|
||||||
|
}
|
||||||
|
r.idleTimer = time.AfterFunc(r.idleTimeout, func() {
|
||||||
|
r.freeONNXMemory()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RAG) freeONNXMemory() {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
|
||||||
|
if err := onnx.Destroy(); err != nil {
|
||||||
|
r.logger.Error("failed to free ONNX memory", "error", err)
|
||||||
|
} else {
|
||||||
|
r.logger.Info("freed ONNX VRAM after idle timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RAG) Destroy() {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if r.idleTimer != nil {
|
||||||
|
r.idleTimer.Stop()
|
||||||
|
r.idleTimer = nil
|
||||||
|
}
|
||||||
|
if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
|
||||||
|
if err := onnx.Destroy(); err != nil {
|
||||||
|
r.logger.Error("failed to destroy ONNX embedder", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user