diff --git a/bot.go b/bot.go index 13ee074..5463800 100644 --- a/bot.go +++ b/bot.go @@ -1393,12 +1393,13 @@ func updateModelLists() { } } // if llama.cpp started after gf-lt? - localModelsMu.Lock() - LocalModels, err = fetchLCPModelsWithLoadStatus() - localModelsMu.Unlock() + ml, err := fetchLCPModelsWithLoadStatus() if err != nil { logger.Warn("failed to fetch llama.cpp models", "error", err) } + localModelsMu.Lock() + LocalModels = ml + localModelsMu.Unlock() // set already loaded model in llama.cpp if strings.Contains(cfg.CurrentAPI, "localhost") || strings.Contains(cfg.CurrentAPI, "127.0.0.1") { localModelsMu.Lock() diff --git a/config.example.toml b/config.example.toml index 39a730b..f5820da 100644 --- a/config.example.toml +++ b/config.example.toml @@ -13,6 +13,9 @@ OpenRouterChatAPI = "https://openrouter.ai/api/v1/chat/completions" # embeddings EmbedURL = "http://localhost:8082/v1/embeddings" HFToken = "" +EmbedModelPath = "onnx/embedgemma/model_q4.onnx" +EmbedTokenizerPath = "onnx/embedgemma/tokenizer.json" +EmbedDims = 768 # ShowSys = true LogFile = "log.txt" diff --git a/config/config.go b/config/config.go index 412eaaa..84ec480 100644 --- a/config/config.go +++ b/config/config.go @@ -34,8 +34,11 @@ type Config struct { ImagePreview bool `toml:"ImagePreview"` EnableMouse bool `toml:"EnableMouse"` // embeddings - EmbedURL string `toml:"EmbedURL"` - HFToken string `toml:"HFToken"` + EmbedURL string `toml:"EmbedURL"` + HFToken string `toml:"HFToken"` + EmbedModelPath string `toml:"EmbedModelPath"` + EmbedTokenizerPath string `toml:"EmbedTokenizerPath"` + EmbedDims int `toml:"EmbedDims"` // rag settings RAGEnabled bool `toml:"RAGEnabled"` RAGDir string `toml:"RAGDir"` diff --git a/rag/embedder.go b/rag/embedder.go index 6903a5d..b0a3226 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -9,6 +9,7 @@ import ( "gf-lt/models" "log/slog" "net/http" + "sync" "github.com/sugarme/tokenizer" "github.com/sugarme/tokenizer/pretrained" @@ -148,7 +149,17 @@ type ONNXEmbedder struct { logger *slog.Logger } +var onnxInitOnce sync.Once + func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) { + // Initialize ONNX runtime environment once + onnxInitOnce.Do(func() { + onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so") + err := onnxruntime_go.InitializeEnvironment() + if err != nil { + logger.Error("failed to initialize ONNX runtime", "error", err) + } + }) // Load tokenizer using sugarme/tokenizer tok, err := pretrained.FromFile(tokenizerPath) if err != nil { @@ -195,7 +206,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { if err != nil { return nil, fmt.Errorf("failed to create input_ids tensor: %w", err) } - defer inputIDsTensor.Destroy() + defer func() { _ = inputIDsTensor.Destroy() }() maskTensor, err := onnxruntime_go.NewTensor[int64]( onnxruntime_go.NewShape(1, seqLen), attentionMask, @@ -203,7 +214,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { if err != nil { return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err) } - defer maskTensor.Destroy() + defer func() { _ = maskTensor.Destroy() }() // 4. Create output tensor outputTensor, err := onnxruntime_go.NewEmptyTensor[float32]( onnxruntime_go.NewShape(1, int64(e.dims)), @@ -211,7 +222,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { if err != nil { return nil, fmt.Errorf("failed to create output tensor: %w", err) } - defer outputTensor.Destroy() + defer func() { _ = outputTensor.Destroy() }() // 5. Run inference err = e.session.Run( []onnxruntime_go.Value{inputIDsTensor, maskTensor}, @@ -257,16 +268,16 @@ func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) { onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), inputIDs, ) - defer inputTensor.Destroy() + defer func() { _ = inputTensor.Destroy() }() maskTensor, _ := onnxruntime_go.NewTensor[int64]( onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), attentionMask, ) - defer maskTensor.Destroy() + defer func() { _ = maskTensor.Destroy() }() outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32]( onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)), ) - defer outputTensor.Destroy() + defer func() { _ = outputTensor.Destroy() }() err := e.session.Run( []onnxruntime_go.Value{inputTensor, maskTensor}, []onnxruntime_go.Value{outputTensor}, diff --git a/rag/rag.go b/rag/rag.go index 3d0f38f..654afde 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -34,8 +34,20 @@ type RAG struct { } func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { - // Initialize with API embedder by default, could be configurable later - embedder := NewAPIEmbedder(l, cfg) + var embedder Embedder + if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" { + emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l) + if err != nil { + l.Error("failed to create ONNX embedder, falling back to API", "error", err) + embedder = NewAPIEmbedder(l, cfg) + } else { + embedder = emb + l.Info("using ONNX embedder", "model", cfg.EmbedModelPath, "dims", cfg.EmbedDims) + } + } else { + embedder = NewAPIEmbedder(l, cfg) + l.Info("using API embedder", "url", cfg.EmbedURL) + } rag := &RAG{ logger: l, store: s,