Chore: onnx library lookup

This commit is contained in:
Grail Finder
2026-03-05 20:02:46 +03:00
parent ac8c8bb055
commit efc92d884c
3 changed files with 122 additions and 42 deletions

8
bot.go
View File

@@ -1501,7 +1501,13 @@ func init() {
os.Exit(1)
return
}
ragger = rag.New(logger, store, cfg)
ragger, err = rag.New(logger, store, cfg)
if err != nil {
logger.Error("failed to create RAG", "error", err)
}
if ragger != nil && ragger.FallbackMessage() != "" && app != nil {
showToast("RAG", "ONNX unavailable, using API: "+ragger.FallbackMessage())
}
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
// load all chats in memory
if _, err := loadHistoryChats(); err != nil {

View File

@@ -9,6 +9,7 @@ import (
"gf-lt/models"
"log/slog"
"net/http"
"os"
"sync"
"github.com/sugarme/tokenizer"
@@ -143,47 +144,111 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
type ONNXEmbedder struct {
// NOTE(review): this listing is a diff rendering, so the first four fields
// look like the pre-change declarations and the remaining ones the
// post-change declarations — confirm against the committed file.
session *onnxruntime_go.DynamicAdvancedSession
tokenizer *tokenizer.Tokenizer
dims int // embedding dimension (e.g., 768)
logger *slog.Logger
// session is nil until ensureInitialized creates it.
session *onnxruntime_go.DynamicAdvancedSession
// tokenizer is nil until ensureInitialized loads it from tokenizerPath.
tokenizer *tokenizer.Tokenizer
// tokenizerPath is the tokenizer file path given to NewONNXEmbedder.
tokenizerPath string
dims int
logger *slog.Logger
// mu is held by ensureInitialized while it sets tokenizer and session.
mu sync.Mutex
// modelPath is the model file path given to NewONNXEmbedder; returned by getModelPath.
modelPath string
}
// Package-level ONNX runtime state shared by every ONNXEmbedder.
var (
	// onnxInitOnce guards the one-time ONNX runtime environment setup.
	onnxInitOnce sync.Once
	// onnxReady records whether that one-time setup succeeded.
	onnxReady bool
	// onnxLibPath is the shared library path selected via findONNXLibrary.
	onnxLibPath string
	// onnxLibPaths lists the locations probed for the ONNX runtime shared
	// library, in priority order.
	onnxLibPaths = []string{
		"/usr/lib/libonnxruntime.so",
		"/usr/local/lib/libonnxruntime.so",
		"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
		"/opt/onnxruntime/lib/libonnxruntime.so",
	}
)

// findONNXLibrary returns the first entry of onnxLibPaths that exists and is
// not a directory, or "" when none qualifies.
//
// os.Stat follows symlinks, so a versioned-library symlink is accepted. A
// directory that happens to sit at a candidate path is skipped: it can never be
// loaded as a shared library and would otherwise mask a valid later entry.
func findONNXLibrary() string {
	for _, path := range onnxLibPaths {
		info, err := os.Stat(path)
		if err != nil || info.IsDir() {
			continue
		}
		return path
	}
	return ""
}
// NewONNXEmbedder checks that the model file, the tokenizer file and an ONNX
// runtime shared library can be found, then returns an embedder that stores the
// two paths, dims and logger. It returns an error when any of the three is
// missing.
// NOTE(review): this span is a diff rendering — the onnxInitOnce.Do call and
// the pretrained.FromFile call at the top (with its unclosed if) look like the
// removed pre-change lines; confirm against the committed file.
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
// Initialize ONNX runtime environment once
onnxInitOnce.Do(func() {
onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so")
err := onnxruntime_go.InitializeEnvironment()
if err != nil {
logger.Error("failed to initialize ONNX runtime", "error", err)
}
})
// Load tokenizer using sugarme/tokenizer
tok, err := pretrained.FromFile(tokenizerPath)
if err != nil {
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
// Check if model and tokenizer files exist
if _, err := os.Stat(modelPath); err != nil {
return nil, fmt.Errorf("ONNX model not found: %w", err)
}
if _, err := os.Stat(tokenizerPath); err != nil {
return nil, fmt.Errorf("tokenizer not found: %w", err)
}
// Find ONNX library
// NOTE(review): this assigns the package-level onnxLibPath on every call and
// no lock is taken here — confirm embedders are never constructed concurrently.
onnxLibPath = findONNXLibrary()
if onnxLibPath == "" {
return nil, errors.New("ONNX runtime library not found in standard locations")
}
// Only the paths and settings are recorded; session and tokenizer stay nil here.
emb := &ONNXEmbedder{
tokenizerPath: tokenizerPath,
dims: dims,
logger: logger,
modelPath: modelPath,
}
return emb, nil
}
// ensureInitialized prepares the embedder on first use: it loads the tokenizer
// from tokenizerPath, initializes the ONNX runtime environment through
// onnxInitOnce, and creates the session that it stores in e.session. It returns
// nil straight away when a session already exists, and an error when the
// tokenizer, the runtime or the session cannot be set up.
// NOTE(review): this span is a diff rendering — the `modelPath,` argument, the
// `nil, // optional options` argument, the two-value error return and the
// `return &ONNXEmbedder{...}` literal look like removed pre-change lines;
// confirm against the committed file.
func (e *ONNXEmbedder) ensureInitialized() error {
// Fast path taken before the mutex is acquired.
// NOTE(review): this read of e.session is not synchronized with the assignment
// made under e.mu at the end of this function — looks like a data race if
// Embed is called from several goroutines; confirm.
if e.session != nil {
return nil
}
e.mu.Lock()
defer e.mu.Unlock()
// Re-check now that the lock is held: another caller may have created the
// session while this one was waiting.
if e.session != nil {
return nil
}
// Load tokenizer lazily
if e.tokenizer == nil {
tok, err := pretrained.FromFile(e.tokenizerPath)
if err != nil {
return fmt.Errorf("failed to load tokenizer: %w", err)
}
e.tokenizer = tok
}
// sync.Once runs this closure at most once, so if InitializeEnvironment fails
// onnxReady stays false and every later call returns the error below.
onnxInitOnce.Do(func() {
onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
if err := onnxruntime_go.InitializeEnvironment(); err != nil {
e.logger.Error("failed to initialize ONNX runtime", "error", err)
onnxReady = false
return
}
onnxReady = true
})
if !onnxReady {
return errors.New("ONNX runtime not ready")
}
// Create ONNX session
session, err := onnxruntime_go.NewDynamicAdvancedSession(
modelPath, // onnx/embedgemma/model_q4.onnx
e.getModelPath(),
[]string{"input_ids", "attention_mask"},
[]string{"sentence_embedding"},
nil, // optional options
nil,
)
if err != nil {
return nil, fmt.Errorf("failed to create ONNX session: %w", err)
return fmt.Errorf("failed to create ONNX session: %w", err)
}
return &ONNXEmbedder{
session: session,
tokenizer: tok,
dims: dims,
logger: logger,
}, nil
// Storing the session makes subsequent calls return early.
e.session = session
return nil
}
// getModelPath reports the model file path stored on the embedder.
func (e *ONNXEmbedder) getModelPath() string {
	path := e.modelPath
	return path
}
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
if err := e.ensureInitialized(); err != nil {
return nil, err
}
// 1. Tokenize
encoding, err := e.tokenizer.EncodeSingle(text)
if err != nil {

View File

@@ -25,20 +25,23 @@ var (
)
// RAG groups the logger, repository, configuration, embedder and vector storage
// that New wires together.
// NOTE(review): this listing is a diff rendering, so the first six fields look
// like the pre-change declarations and the rest the post-change declarations —
// confirm against the committed file.
type RAG struct {
logger *slog.Logger
store storage.FullRepo
cfg *config.Config
embedder Embedder
storage *VectorStorage
mu sync.Mutex
// logger is the logger passed to New.
logger *slog.Logger
// store is the repository passed to New.
store storage.FullRepo
// cfg is the configuration passed to New.
cfg *config.Config
// embedder is the embedder chosen by New (ONNX or API).
embedder Embedder
// storage is built by NewVectorStorage in New.
storage *VectorStorage
mu sync.Mutex
// fallbackMsg holds the error text from a failed NewONNXEmbedder call;
// it is returned by FallbackMessage.
fallbackMsg string
}
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
var embedder Embedder
var fallbackMsg string
if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" {
emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
if err != nil {
l.Error("failed to create ONNX embedder, falling back to API", "error", err)
fallbackMsg = err.Error()
embedder = NewAPIEmbedder(l, cfg)
} else {
embedder = emb
@@ -49,16 +52,17 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
l.Info("using API embedder", "url", cfg.EmbedURL)
}
rag := &RAG{
logger: l,
store: s,
cfg: cfg,
embedder: embedder,
storage: NewVectorStorage(l, s),
logger: l,
store: s,
cfg: cfg,
embedder: embedder,
storage: NewVectorStorage(l, s),
fallbackMsg: fallbackMsg,
}
// Note: Vector tables are created via database migrations, not at runtime
return rag
return rag, nil
}
func wordCounter(sentence string) int {
@@ -449,14 +453,19 @@ var (
ragOnce sync.Once
)
// FallbackMessage returns the error text recorded when the ONNX embedder could
// not be created and the API embedder was used instead. It returns "" when no
// fallback was recorded, and also for a nil receiver, so a *RAG that was never
// constructed can be queried without panicking.
func (r *RAG) FallbackMessage() string {
	if r == nil {
		return ""
	}
	return r.fallbackMsg
}
// Init assigns the package-level ragInstance from New, guarded by ragOnce so
// the construction closure runs at most once.
// NOTE(review): this span is a diff rendering — the single-value
// `ragInstance = New(...)` assignment and `return nil` look like removed
// pre-change lines; confirm against the committed file.
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
// err is local to this call; it is only set when the closure below runs New.
var err error
ragOnce.Do(func() {
// NOTE(review): with a nil argument the closure returns without assigning
// ragInstance, yet ragOnce is still consumed and err stays nil — a later call
// with valid arguments will not run New; confirm this is intended.
if c == nil || l == nil || s == nil {
return
}
ragInstance = New(l, s, c)
ragInstance, err = New(l, s, c)
})
return nil
return err
}
func GetInstance() *RAG {