Chore: onnx library lookup

This commit is contained in:
Grail Finder
2026-03-05 20:02:46 +03:00
parent ac8c8bb055
commit efc92d884c
3 changed files with 122 additions and 42 deletions

8
bot.go
View File

@@ -1501,7 +1501,13 @@ func init() {
os.Exit(1) os.Exit(1)
return return
} }
ragger = rag.New(logger, store, cfg) ragger, err = rag.New(logger, store, cfg)
if err != nil {
logger.Error("failed to create RAG", "error", err)
}
if ragger != nil && ragger.FallbackMessage() != "" && app != nil {
showToast("RAG", "ONNX unavailable, using API: "+ragger.FallbackMessage())
}
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md // https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
// load all chats in memory // load all chats in memory
if _, err := loadHistoryChats(); err != nil { if _, err := loadHistoryChats(); err != nil {

View File

@@ -9,6 +9,7 @@ import (
"gf-lt/models" "gf-lt/models"
"log/slog" "log/slog"
"net/http" "net/http"
"os"
"sync" "sync"
"github.com/sugarme/tokenizer" "github.com/sugarme/tokenizer"
@@ -145,45 +146,109 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
type ONNXEmbedder struct { type ONNXEmbedder struct {
session *onnxruntime_go.DynamicAdvancedSession session *onnxruntime_go.DynamicAdvancedSession
tokenizer *tokenizer.Tokenizer tokenizer *tokenizer.Tokenizer
dims int // embedding dimension (e.g., 768) tokenizerPath string
dims int
logger *slog.Logger logger *slog.Logger
mu sync.Mutex
modelPath string
} }
var onnxInitOnce sync.Once var onnxInitOnce sync.Once
var onnxReady bool
var onnxLibPath string
var onnxLibPaths = []string{
"/usr/lib/libonnxruntime.so",
"/usr/local/lib/libonnxruntime.so",
"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
"/opt/onnxruntime/lib/libonnxruntime.so",
}
func findONNXLibrary() string {
for _, path := range onnxLibPaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return ""
}
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) { func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
// Initialize ONNX runtime environment once // Check if model and tokenizer files exist
onnxInitOnce.Do(func() { if _, err := os.Stat(modelPath); err != nil {
onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so") return nil, fmt.Errorf("ONNX model not found: %w", err)
err := onnxruntime_go.InitializeEnvironment()
if err != nil {
logger.Error("failed to initialize ONNX runtime", "error", err)
} }
}) if _, err := os.Stat(tokenizerPath); err != nil {
// Load tokenizer using sugarme/tokenizer return nil, fmt.Errorf("tokenizer not found: %w", err)
tok, err := pretrained.FromFile(tokenizerPath)
if err != nil {
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
} }
// Create ONNX session
session, err := onnxruntime_go.NewDynamicAdvancedSession( // Find ONNX library
modelPath, // onnx/embedgemma/model_q4.onnx onnxLibPath = findONNXLibrary()
[]string{"input_ids", "attention_mask"}, if onnxLibPath == "" {
[]string{"sentence_embedding"}, return nil, errors.New("ONNX runtime library not found in standard locations")
nil, // optional options
)
if err != nil {
return nil, fmt.Errorf("failed to create ONNX session: %w", err)
} }
return &ONNXEmbedder{
session: session, emb := &ONNXEmbedder{
tokenizer: tok, tokenizerPath: tokenizerPath,
dims: dims, dims: dims,
logger: logger, logger: logger,
}, nil modelPath: modelPath,
}
return emb, nil
}
func (e *ONNXEmbedder) ensureInitialized() error {
if e.session != nil {
return nil
}
e.mu.Lock()
defer e.mu.Unlock()
if e.session != nil {
return nil
}
// Load tokenizer lazily
if e.tokenizer == nil {
tok, err := pretrained.FromFile(e.tokenizerPath)
if err != nil {
return fmt.Errorf("failed to load tokenizer: %w", err)
}
e.tokenizer = tok
}
onnxInitOnce.Do(func() {
onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
if err := onnxruntime_go.InitializeEnvironment(); err != nil {
e.logger.Error("failed to initialize ONNX runtime", "error", err)
onnxReady = false
return
}
onnxReady = true
})
if !onnxReady {
return errors.New("ONNX runtime not ready")
}
session, err := onnxruntime_go.NewDynamicAdvancedSession(
e.getModelPath(),
[]string{"input_ids", "attention_mask"},
[]string{"sentence_embedding"},
nil,
)
if err != nil {
return fmt.Errorf("failed to create ONNX session: %w", err)
}
e.session = session
return nil
}
func (e *ONNXEmbedder) getModelPath() string {
return e.modelPath
} }
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
if err := e.ensureInitialized(); err != nil {
return nil, err
}
// 1. Tokenize // 1. Tokenize
encoding, err := e.tokenizer.EncodeSingle(text) encoding, err := e.tokenizer.EncodeSingle(text)
if err != nil { if err != nil {

View File

@@ -31,14 +31,17 @@ type RAG struct {
embedder Embedder embedder Embedder
storage *VectorStorage storage *VectorStorage
mu sync.Mutex mu sync.Mutex
fallbackMsg string
} }
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
var embedder Embedder var embedder Embedder
var fallbackMsg string
if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" { if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" {
emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l) emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
if err != nil { if err != nil {
l.Error("failed to create ONNX embedder, falling back to API", "error", err) l.Error("failed to create ONNX embedder, falling back to API", "error", err)
fallbackMsg = err.Error()
embedder = NewAPIEmbedder(l, cfg) embedder = NewAPIEmbedder(l, cfg)
} else { } else {
embedder = emb embedder = emb
@@ -54,11 +57,12 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
cfg: cfg, cfg: cfg,
embedder: embedder, embedder: embedder,
storage: NewVectorStorage(l, s), storage: NewVectorStorage(l, s),
fallbackMsg: fallbackMsg,
} }
// Note: Vector tables are created via database migrations, not at runtime // Note: Vector tables are created via database migrations, not at runtime
return rag return rag, nil
} }
func wordCounter(sentence string) int { func wordCounter(sentence string) int {
@@ -449,14 +453,19 @@ var (
ragOnce sync.Once ragOnce sync.Once
) )
func (r *RAG) FallbackMessage() string {
return r.fallbackMsg
}
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error { func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
var err error
ragOnce.Do(func() { ragOnce.Do(func() {
if c == nil || l == nil || s == nil { if c == nil || l == nil || s == nil {
return return
} }
ragInstance = New(l, s, c) ragInstance, err = New(l, s, c)
}) })
return nil return err
} }
func GetInstance() *RAG { func GetInstance() *RAG {