Chore: onnx library lookup
This commit is contained in:
8
bot.go
8
bot.go
@@ -1501,7 +1501,13 @@ func init() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ragger = rag.New(logger, store, cfg)
|
ragger, err = rag.New(logger, store, cfg)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to create RAG", "error", err)
|
||||||
|
}
|
||||||
|
if ragger != nil && ragger.FallbackMessage() != "" && app != nil {
|
||||||
|
showToast("RAG", "ONNX unavailable, using API: "+ragger.FallbackMessage())
|
||||||
|
}
|
||||||
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
|
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
|
||||||
// load all chats in memory
|
// load all chats in memory
|
||||||
if _, err := loadHistoryChats(); err != nil {
|
if _, err := loadHistoryChats(); err != nil {
|
||||||
|
|||||||
115
rag/embedder.go
115
rag/embedder.go
@@ -9,6 +9,7 @@ import (
|
|||||||
"gf-lt/models"
|
"gf-lt/models"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/sugarme/tokenizer"
|
"github.com/sugarme/tokenizer"
|
||||||
@@ -145,45 +146,109 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
|||||||
type ONNXEmbedder struct {
|
type ONNXEmbedder struct {
|
||||||
session *onnxruntime_go.DynamicAdvancedSession
|
session *onnxruntime_go.DynamicAdvancedSession
|
||||||
tokenizer *tokenizer.Tokenizer
|
tokenizer *tokenizer.Tokenizer
|
||||||
dims int // embedding dimension (e.g., 768)
|
tokenizerPath string
|
||||||
|
dims int
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
|
mu sync.Mutex
|
||||||
|
modelPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
var onnxInitOnce sync.Once
|
var onnxInitOnce sync.Once
|
||||||
|
var onnxReady bool
|
||||||
|
var onnxLibPath string
|
||||||
|
|
||||||
|
var onnxLibPaths = []string{
|
||||||
|
"/usr/lib/libonnxruntime.so",
|
||||||
|
"/usr/local/lib/libonnxruntime.so",
|
||||||
|
"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
|
||||||
|
"/opt/onnxruntime/lib/libonnxruntime.so",
|
||||||
|
}
|
||||||
|
|
||||||
|
func findONNXLibrary() string {
|
||||||
|
for _, path := range onnxLibPaths {
|
||||||
|
if _, err := os.Stat(path); err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
|
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
|
||||||
// Initialize ONNX runtime environment once
|
// Check if model and tokenizer files exist
|
||||||
onnxInitOnce.Do(func() {
|
if _, err := os.Stat(modelPath); err != nil {
|
||||||
onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so")
|
return nil, fmt.Errorf("ONNX model not found: %w", err)
|
||||||
err := onnxruntime_go.InitializeEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to initialize ONNX runtime", "error", err)
|
|
||||||
}
|
}
|
||||||
})
|
if _, err := os.Stat(tokenizerPath); err != nil {
|
||||||
// Load tokenizer using sugarme/tokenizer
|
return nil, fmt.Errorf("tokenizer not found: %w", err)
|
||||||
tok, err := pretrained.FromFile(tokenizerPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
|
|
||||||
}
|
}
|
||||||
// Create ONNX session
|
|
||||||
session, err := onnxruntime_go.NewDynamicAdvancedSession(
|
// Find ONNX library
|
||||||
modelPath, // onnx/embedgemma/model_q4.onnx
|
onnxLibPath = findONNXLibrary()
|
||||||
[]string{"input_ids", "attention_mask"},
|
if onnxLibPath == "" {
|
||||||
[]string{"sentence_embedding"},
|
return nil, errors.New("ONNX runtime library not found in standard locations")
|
||||||
nil, // optional options
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create ONNX session: %w", err)
|
|
||||||
}
|
}
|
||||||
return &ONNXEmbedder{
|
|
||||||
session: session,
|
emb := &ONNXEmbedder{
|
||||||
tokenizer: tok,
|
tokenizerPath: tokenizerPath,
|
||||||
dims: dims,
|
dims: dims,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}, nil
|
modelPath: modelPath,
|
||||||
|
}
|
||||||
|
return emb, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ONNXEmbedder) ensureInitialized() error {
|
||||||
|
if e.session != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
if e.session != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load tokenizer lazily
|
||||||
|
if e.tokenizer == nil {
|
||||||
|
tok, err := pretrained.FromFile(e.tokenizerPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load tokenizer: %w", err)
|
||||||
|
}
|
||||||
|
e.tokenizer = tok
|
||||||
|
}
|
||||||
|
|
||||||
|
onnxInitOnce.Do(func() {
|
||||||
|
onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
|
||||||
|
if err := onnxruntime_go.InitializeEnvironment(); err != nil {
|
||||||
|
e.logger.Error("failed to initialize ONNX runtime", "error", err)
|
||||||
|
onnxReady = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
onnxReady = true
|
||||||
|
})
|
||||||
|
if !onnxReady {
|
||||||
|
return errors.New("ONNX runtime not ready")
|
||||||
|
}
|
||||||
|
session, err := onnxruntime_go.NewDynamicAdvancedSession(
|
||||||
|
e.getModelPath(),
|
||||||
|
[]string{"input_ids", "attention_mask"},
|
||||||
|
[]string{"sentence_embedding"},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create ONNX session: %w", err)
|
||||||
|
}
|
||||||
|
e.session = session
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ONNXEmbedder) getModelPath() string {
|
||||||
|
return e.modelPath
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||||
|
if err := e.ensureInitialized(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
// 1. Tokenize
|
// 1. Tokenize
|
||||||
encoding, err := e.tokenizer.EncodeSingle(text)
|
encoding, err := e.tokenizer.EncodeSingle(text)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
17
rag/rag.go
17
rag/rag.go
@@ -31,14 +31,17 @@ type RAG struct {
|
|||||||
embedder Embedder
|
embedder Embedder
|
||||||
storage *VectorStorage
|
storage *VectorStorage
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
fallbackMsg string
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
||||||
var embedder Embedder
|
var embedder Embedder
|
||||||
|
var fallbackMsg string
|
||||||
if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" {
|
if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" {
|
||||||
emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
|
emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Error("failed to create ONNX embedder, falling back to API", "error", err)
|
l.Error("failed to create ONNX embedder, falling back to API", "error", err)
|
||||||
|
fallbackMsg = err.Error()
|
||||||
embedder = NewAPIEmbedder(l, cfg)
|
embedder = NewAPIEmbedder(l, cfg)
|
||||||
} else {
|
} else {
|
||||||
embedder = emb
|
embedder = emb
|
||||||
@@ -54,11 +57,12 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
|||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
embedder: embedder,
|
embedder: embedder,
|
||||||
storage: NewVectorStorage(l, s),
|
storage: NewVectorStorage(l, s),
|
||||||
|
fallbackMsg: fallbackMsg,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: Vector tables are created via database migrations, not at runtime
|
// Note: Vector tables are created via database migrations, not at runtime
|
||||||
|
|
||||||
return rag
|
return rag, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func wordCounter(sentence string) int {
|
func wordCounter(sentence string) int {
|
||||||
@@ -449,14 +453,19 @@ var (
|
|||||||
ragOnce sync.Once
|
ragOnce sync.Once
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (r *RAG) FallbackMessage() string {
|
||||||
|
return r.fallbackMsg
|
||||||
|
}
|
||||||
|
|
||||||
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
|
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
|
||||||
|
var err error
|
||||||
ragOnce.Do(func() {
|
ragOnce.Do(func() {
|
||||||
if c == nil || l == nil || s == nil {
|
if c == nil || l == nil || s == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ragInstance = New(l, s, c)
|
ragInstance, err = New(l, s, c)
|
||||||
})
|
})
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetInstance() *RAG {
|
func GetInstance() *RAG {
|
||||||
|
|||||||
Reference in New Issue
Block a user