From efc92d884c36498220e2b8d5ad9e02f84e42d953 Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Thu, 5 Mar 2026 20:02:46 +0300 Subject: [PATCH] Chore: onnx library lookup --- bot.go | 8 +++- rag/embedder.go | 117 +++++++++++++++++++++++++++++++++++++----------- rag/rag.go | 39 +++++++++------- 3 files changed, 122 insertions(+), 42 deletions(-) diff --git a/bot.go b/bot.go index 5463800..20ffeb2 100644 --- a/bot.go +++ b/bot.go @@ -1501,7 +1501,13 @@ func init() { os.Exit(1) return } - ragger = rag.New(logger, store, cfg) + ragger, err = rag.New(logger, store, cfg) + if err != nil { + logger.Error("failed to create RAG", "error", err) + } + if ragger != nil && ragger.FallbackMessage() != "" && app != nil { + showToast("RAG", "ONNX unavailable, using API: "+ragger.FallbackMessage()) + } // https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md // load all chats in memory if _, err := loadHistoryChats(); err != nil { diff --git a/rag/embedder.go b/rag/embedder.go index b0a3226..59dbfd2 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -9,6 +9,7 @@ import ( "gf-lt/models" "log/slog" "net/http" + "os" "sync" "github.com/sugarme/tokenizer" @@ -143,47 +144,111 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) { // 2. Using a Go ONNX runtime (like gorgonia/onnx or similar) // 3. Converting text to embeddings without external API calls type ONNXEmbedder struct { - session *onnxruntime_go.DynamicAdvancedSession - tokenizer *tokenizer.Tokenizer - dims int // embedding dimension (e.g., 768) - logger *slog.Logger + session *onnxruntime_go.DynamicAdvancedSession + tokenizer *tokenizer.Tokenizer + tokenizerPath string + dims int + logger *slog.Logger + mu sync.Mutex + modelPath string } var onnxInitOnce sync.Once +var onnxReady bool +var onnxLibPath string + +var onnxLibPaths = []string{ + "/usr/lib/libonnxruntime.so", + "/usr/local/lib/libonnxruntime.so", + "/usr/lib/x86_64-linux-gnu/libonnxruntime.so", + "/opt/onnxruntime/lib/libonnxruntime.so", +} + +func findONNXLibrary() string { + for _, path := range onnxLibPaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + return "" +} func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) { - // Initialize ONNX runtime environment once - onnxInitOnce.Do(func() { - onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so") - err := onnxruntime_go.InitializeEnvironment() - if err != nil { - logger.Error("failed to initialize ONNX runtime", "error", err) - } - }) - // Load tokenizer using sugarme/tokenizer - tok, err := pretrained.FromFile(tokenizerPath) - if err != nil { - return nil, fmt.Errorf("failed to load tokenizer: %w", err) + // Check if model and tokenizer files exist + if _, err := os.Stat(modelPath); err != nil { + return nil, fmt.Errorf("ONNX model not found: %w", err) + } + if _, err := os.Stat(tokenizerPath); err != nil { + return nil, fmt.Errorf("tokenizer not found: %w", err) + } + + // Find ONNX library + onnxLibPath = findONNXLibrary() + if onnxLibPath == "" { + return nil, errors.New("ONNX runtime library not found in standard locations") + } + + emb := &ONNXEmbedder{ + tokenizerPath: tokenizerPath, + dims: dims, + logger: logger, + modelPath: modelPath, + } + return emb, nil +} + +func (e *ONNXEmbedder) ensureInitialized() error { + if e.session != nil { + return nil + } + e.mu.Lock() + defer e.mu.Unlock() + if e.session != nil { + return nil + } + + // Load tokenizer lazily + if e.tokenizer == nil { + tok, err := pretrained.FromFile(e.tokenizerPath) + if err != nil { + return fmt.Errorf("failed to load tokenizer: %w", err) + } + e.tokenizer = tok + } + + onnxInitOnce.Do(func() { + onnxruntime_go.SetSharedLibraryPath(onnxLibPath) + if err := onnxruntime_go.InitializeEnvironment(); err != nil { + e.logger.Error("failed to initialize ONNX runtime", "error", err) + onnxReady = false + return + } + onnxReady = true + }) + if !onnxReady { + return errors.New("ONNX runtime not ready") } - // Create ONNX session session, err := onnxruntime_go.NewDynamicAdvancedSession( - modelPath, // onnx/embedgemma/model_q4.onnx + e.getModelPath(), []string{"input_ids", "attention_mask"}, []string{"sentence_embedding"}, - nil, // optional options + nil, ) if err != nil { - return nil, fmt.Errorf("failed to create ONNX session: %w", err) + return fmt.Errorf("failed to create ONNX session: %w", err) } - return &ONNXEmbedder{ - session: session, - tokenizer: tok, - dims: dims, - logger: logger, - }, nil + e.session = session + return nil +} + +func (e *ONNXEmbedder) getModelPath() string { + return e.modelPath } func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { + if err := e.ensureInitialized(); err != nil { + return nil, err + } // 1. Tokenize encoding, err := e.tokenizer.EncodeSingle(text) if err != nil { diff --git a/rag/rag.go b/rag/rag.go index 654afde..fa30303 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -25,20 +25,23 @@ var ( ) type RAG struct { - logger *slog.Logger - store storage.FullRepo - cfg *config.Config - embedder Embedder - storage *VectorStorage - mu sync.Mutex + logger *slog.Logger + store storage.FullRepo + cfg *config.Config + embedder Embedder + storage *VectorStorage + mu sync.Mutex + fallbackMsg string } -func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { +func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) { var embedder Embedder + var fallbackMsg string if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" { emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l) if err != nil { l.Error("failed to create ONNX embedder, falling back to API", "error", err) + fallbackMsg = err.Error() embedder = NewAPIEmbedder(l, cfg) } else { embedder = emb @@ -49,16 +52,17 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { l.Info("using API embedder", "url", cfg.EmbedURL) } rag := &RAG{ - logger: l, - store: s, - cfg: cfg, - embedder: embedder, - storage: NewVectorStorage(l, s), + logger: l, + store: s, + cfg: cfg, + embedder: embedder, + storage: NewVectorStorage(l, s), + fallbackMsg: fallbackMsg, } // Note: Vector tables are created via database migrations, not at runtime - return rag + return rag, nil } func wordCounter(sentence string) int { @@ -449,14 +453,19 @@ var ( ragOnce sync.Once ) +func (r *RAG) FallbackMessage() string { + return r.fallbackMsg +} + func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error { + var err error ragOnce.Do(func() { if c == nil || l == nil || s == nil { return } - ragInstance = New(l, s, c) + ragInstance, err = New(l, s, c) }) - return nil + return err } func GetInstance() *RAG {