Chore: onnx library lookup

This commit is contained in:
Grail Finder
2026-03-05 20:02:46 +03:00
parent ac8c8bb055
commit efc92d884c
3 changed files with 122 additions and 42 deletions

View File

@@ -9,6 +9,7 @@ import (
"gf-lt/models"
"log/slog"
"net/http"
"os"
"sync"
"github.com/sugarme/tokenizer"
@@ -143,47 +144,111 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
type ONNXEmbedder struct {
session *onnxruntime_go.DynamicAdvancedSession
tokenizer *tokenizer.Tokenizer
dims int // embedding dimension (e.g., 768)
logger *slog.Logger
session *onnxruntime_go.DynamicAdvancedSession
tokenizer *tokenizer.Tokenizer
tokenizerPath string
dims int
logger *slog.Logger
mu sync.Mutex
modelPath string
}
var onnxInitOnce sync.Once
var onnxReady bool
var onnxLibPath string
var onnxLibPaths = []string{
"/usr/lib/libonnxruntime.so",
"/usr/local/lib/libonnxruntime.so",
"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
"/opt/onnxruntime/lib/libonnxruntime.so",
}
func findONNXLibrary() string {
for _, path := range onnxLibPaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return ""
}
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
// Initialize ONNX runtime environment once
onnxInitOnce.Do(func() {
onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so")
err := onnxruntime_go.InitializeEnvironment()
if err != nil {
logger.Error("failed to initialize ONNX runtime", "error", err)
}
})
// Load tokenizer using sugarme/tokenizer
tok, err := pretrained.FromFile(tokenizerPath)
if err != nil {
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
// Check if model and tokenizer files exist
if _, err := os.Stat(modelPath); err != nil {
return nil, fmt.Errorf("ONNX model not found: %w", err)
}
if _, err := os.Stat(tokenizerPath); err != nil {
return nil, fmt.Errorf("tokenizer not found: %w", err)
}
// Find ONNX library
onnxLibPath = findONNXLibrary()
if onnxLibPath == "" {
return nil, errors.New("ONNX runtime library not found in standard locations")
}
emb := &ONNXEmbedder{
tokenizerPath: tokenizerPath,
dims: dims,
logger: logger,
modelPath: modelPath,
}
return emb, nil
}
func (e *ONNXEmbedder) ensureInitialized() error {
if e.session != nil {
return nil
}
e.mu.Lock()
defer e.mu.Unlock()
if e.session != nil {
return nil
}
// Load tokenizer lazily
if e.tokenizer == nil {
tok, err := pretrained.FromFile(e.tokenizerPath)
if err != nil {
return fmt.Errorf("failed to load tokenizer: %w", err)
}
e.tokenizer = tok
}
onnxInitOnce.Do(func() {
onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
if err := onnxruntime_go.InitializeEnvironment(); err != nil {
e.logger.Error("failed to initialize ONNX runtime", "error", err)
onnxReady = false
return
}
onnxReady = true
})
if !onnxReady {
return errors.New("ONNX runtime not ready")
}
// Create ONNX session
session, err := onnxruntime_go.NewDynamicAdvancedSession(
modelPath, // onnx/embedgemma/model_q4.onnx
e.getModelPath(),
[]string{"input_ids", "attention_mask"},
[]string{"sentence_embedding"},
nil, // optional options
nil,
)
if err != nil {
return nil, fmt.Errorf("failed to create ONNX session: %w", err)
return fmt.Errorf("failed to create ONNX session: %w", err)
}
return &ONNXEmbedder{
session: session,
tokenizer: tok,
dims: dims,
logger: logger,
}, nil
e.session = session
return nil
}
func (e *ONNXEmbedder) getModelPath() string {
return e.modelPath
}
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
if err := e.ensureInitialized(); err != nil {
return nil, err
}
// 1. Tokenize
encoding, err := e.tokenizer.EncodeSingle(text)
if err != nil {