Chore: onnx library lookup
This commit is contained in:
8
bot.go
8
bot.go
@@ -1501,7 +1501,13 @@ func init() {
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
ragger = rag.New(logger, store, cfg)
|
||||
ragger, err = rag.New(logger, store, cfg)
|
||||
if err != nil {
|
||||
logger.Error("failed to create RAG", "error", err)
|
||||
}
|
||||
if ragger != nil && ragger.FallbackMessage() != "" && app != nil {
|
||||
showToast("RAG", "ONNX unavailable, using API: "+ragger.FallbackMessage())
|
||||
}
|
||||
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
|
||||
// load all chats in memory
|
||||
if _, err := loadHistoryChats(); err != nil {
|
||||
|
||||
117
rag/embedder.go
117
rag/embedder.go
@@ -9,6 +9,7 @@ import (
|
||||
"gf-lt/models"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/sugarme/tokenizer"
|
||||
@@ -143,47 +144,111 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
||||
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
|
||||
// 3. Converting text to embeddings without external API calls
|
||||
type ONNXEmbedder struct {
|
||||
session *onnxruntime_go.DynamicAdvancedSession
|
||||
tokenizer *tokenizer.Tokenizer
|
||||
dims int // embedding dimension (e.g., 768)
|
||||
logger *slog.Logger
|
||||
session *onnxruntime_go.DynamicAdvancedSession
|
||||
tokenizer *tokenizer.Tokenizer
|
||||
tokenizerPath string
|
||||
dims int
|
||||
logger *slog.Logger
|
||||
mu sync.Mutex
|
||||
modelPath string
|
||||
}
|
||||
|
||||
var onnxInitOnce sync.Once
|
||||
var onnxReady bool
|
||||
var onnxLibPath string
|
||||
|
||||
var onnxLibPaths = []string{
|
||||
"/usr/lib/libonnxruntime.so",
|
||||
"/usr/local/lib/libonnxruntime.so",
|
||||
"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
|
||||
"/opt/onnxruntime/lib/libonnxruntime.so",
|
||||
}
|
||||
|
||||
func findONNXLibrary() string {
|
||||
for _, path := range onnxLibPaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
|
||||
// Initialize ONNX runtime environment once
|
||||
onnxInitOnce.Do(func() {
|
||||
onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so")
|
||||
err := onnxruntime_go.InitializeEnvironment()
|
||||
if err != nil {
|
||||
logger.Error("failed to initialize ONNX runtime", "error", err)
|
||||
}
|
||||
})
|
||||
// Load tokenizer using sugarme/tokenizer
|
||||
tok, err := pretrained.FromFile(tokenizerPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
|
||||
// Check if model and tokenizer files exist
|
||||
if _, err := os.Stat(modelPath); err != nil {
|
||||
return nil, fmt.Errorf("ONNX model not found: %w", err)
|
||||
}
|
||||
if _, err := os.Stat(tokenizerPath); err != nil {
|
||||
return nil, fmt.Errorf("tokenizer not found: %w", err)
|
||||
}
|
||||
|
||||
// Find ONNX library
|
||||
onnxLibPath = findONNXLibrary()
|
||||
if onnxLibPath == "" {
|
||||
return nil, errors.New("ONNX runtime library not found in standard locations")
|
||||
}
|
||||
|
||||
emb := &ONNXEmbedder{
|
||||
tokenizerPath: tokenizerPath,
|
||||
dims: dims,
|
||||
logger: logger,
|
||||
modelPath: modelPath,
|
||||
}
|
||||
return emb, nil
|
||||
}
|
||||
|
||||
func (e *ONNXEmbedder) ensureInitialized() error {
|
||||
if e.session != nil {
|
||||
return nil
|
||||
}
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
if e.session != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load tokenizer lazily
|
||||
if e.tokenizer == nil {
|
||||
tok, err := pretrained.FromFile(e.tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load tokenizer: %w", err)
|
||||
}
|
||||
e.tokenizer = tok
|
||||
}
|
||||
|
||||
onnxInitOnce.Do(func() {
|
||||
onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
|
||||
if err := onnxruntime_go.InitializeEnvironment(); err != nil {
|
||||
e.logger.Error("failed to initialize ONNX runtime", "error", err)
|
||||
onnxReady = false
|
||||
return
|
||||
}
|
||||
onnxReady = true
|
||||
})
|
||||
if !onnxReady {
|
||||
return errors.New("ONNX runtime not ready")
|
||||
}
|
||||
// Create ONNX session
|
||||
session, err := onnxruntime_go.NewDynamicAdvancedSession(
|
||||
modelPath, // onnx/embedgemma/model_q4.onnx
|
||||
e.getModelPath(),
|
||||
[]string{"input_ids", "attention_mask"},
|
||||
[]string{"sentence_embedding"},
|
||||
nil, // optional options
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ONNX session: %w", err)
|
||||
return fmt.Errorf("failed to create ONNX session: %w", err)
|
||||
}
|
||||
return &ONNXEmbedder{
|
||||
session: session,
|
||||
tokenizer: tok,
|
||||
dims: dims,
|
||||
logger: logger,
|
||||
}, nil
|
||||
e.session = session
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *ONNXEmbedder) getModelPath() string {
|
||||
return e.modelPath
|
||||
}
|
||||
|
||||
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||
if err := e.ensureInitialized(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 1. Tokenize
|
||||
encoding, err := e.tokenizer.EncodeSingle(text)
|
||||
if err != nil {
|
||||
|
||||
39
rag/rag.go
39
rag/rag.go
@@ -25,20 +25,23 @@ var (
|
||||
)
|
||||
|
||||
type RAG struct {
|
||||
logger *slog.Logger
|
||||
store storage.FullRepo
|
||||
cfg *config.Config
|
||||
embedder Embedder
|
||||
storage *VectorStorage
|
||||
mu sync.Mutex
|
||||
logger *slog.Logger
|
||||
store storage.FullRepo
|
||||
cfg *config.Config
|
||||
embedder Embedder
|
||||
storage *VectorStorage
|
||||
mu sync.Mutex
|
||||
fallbackMsg string
|
||||
}
|
||||
|
||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
||||
var embedder Embedder
|
||||
var fallbackMsg string
|
||||
if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" {
|
||||
emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
|
||||
if err != nil {
|
||||
l.Error("failed to create ONNX embedder, falling back to API", "error", err)
|
||||
fallbackMsg = err.Error()
|
||||
embedder = NewAPIEmbedder(l, cfg)
|
||||
} else {
|
||||
embedder = emb
|
||||
@@ -49,16 +52,17 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
||||
l.Info("using API embedder", "url", cfg.EmbedURL)
|
||||
}
|
||||
rag := &RAG{
|
||||
logger: l,
|
||||
store: s,
|
||||
cfg: cfg,
|
||||
embedder: embedder,
|
||||
storage: NewVectorStorage(l, s),
|
||||
logger: l,
|
||||
store: s,
|
||||
cfg: cfg,
|
||||
embedder: embedder,
|
||||
storage: NewVectorStorage(l, s),
|
||||
fallbackMsg: fallbackMsg,
|
||||
}
|
||||
|
||||
// Note: Vector tables are created via database migrations, not at runtime
|
||||
|
||||
return rag
|
||||
return rag, nil
|
||||
}
|
||||
|
||||
func wordCounter(sentence string) int {
|
||||
@@ -449,14 +453,19 @@ var (
|
||||
ragOnce sync.Once
|
||||
)
|
||||
|
||||
func (r *RAG) FallbackMessage() string {
|
||||
return r.fallbackMsg
|
||||
}
|
||||
|
||||
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
|
||||
var err error
|
||||
ragOnce.Do(func() {
|
||||
if c == nil || l == nil || s == nil {
|
||||
return
|
||||
}
|
||||
ragInstance = New(l, s, c)
|
||||
ragInstance, err = New(l, s, c)
|
||||
})
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func GetInstance() *RAG {
|
||||
|
||||
Reference in New Issue
Block a user