Enha: ONNX embedder config vars
This commit is contained in:
7
bot.go
7
bot.go
@@ -1393,12 +1393,13 @@ func updateModelLists() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if llama.cpp started after gf-lt?
|
// if llama.cpp started after gf-lt?
|
||||||
localModelsMu.Lock()
|
ml, err := fetchLCPModelsWithLoadStatus()
|
||||||
LocalModels, err = fetchLCPModelsWithLoadStatus()
|
|
||||||
localModelsMu.Unlock()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("failed to fetch llama.cpp models", "error", err)
|
logger.Warn("failed to fetch llama.cpp models", "error", err)
|
||||||
}
|
}
|
||||||
|
localModelsMu.Lock()
|
||||||
|
LocalModels = ml
|
||||||
|
localModelsMu.Unlock()
|
||||||
// set already loaded model in llama.cpp
|
// set already loaded model in llama.cpp
|
||||||
if strings.Contains(cfg.CurrentAPI, "localhost") || strings.Contains(cfg.CurrentAPI, "127.0.0.1") {
|
if strings.Contains(cfg.CurrentAPI, "localhost") || strings.Contains(cfg.CurrentAPI, "127.0.0.1") {
|
||||||
localModelsMu.Lock()
|
localModelsMu.Lock()
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ OpenRouterChatAPI = "https://openrouter.ai/api/v1/chat/completions"
|
|||||||
# embeddings
|
# embeddings
|
||||||
EmbedURL = "http://localhost:8082/v1/embeddings"
|
EmbedURL = "http://localhost:8082/v1/embeddings"
|
||||||
HFToken = ""
|
HFToken = ""
|
||||||
|
EmbedModelPath = "onnx/embedgemma/model_q4.onnx"
|
||||||
|
EmbedTokenizerPath = "onnx/embedgemma/tokenizer.json"
|
||||||
|
EmbedDims = 768
|
||||||
#
|
#
|
||||||
ShowSys = true
|
ShowSys = true
|
||||||
LogFile = "log.txt"
|
LogFile = "log.txt"
|
||||||
|
|||||||
@@ -34,8 +34,11 @@ type Config struct {
|
|||||||
ImagePreview bool `toml:"ImagePreview"`
|
ImagePreview bool `toml:"ImagePreview"`
|
||||||
EnableMouse bool `toml:"EnableMouse"`
|
EnableMouse bool `toml:"EnableMouse"`
|
||||||
// embeddings
|
// embeddings
|
||||||
EmbedURL string `toml:"EmbedURL"`
|
EmbedURL string `toml:"EmbedURL"`
|
||||||
HFToken string `toml:"HFToken"`
|
HFToken string `toml:"HFToken"`
|
||||||
|
EmbedModelPath string `toml:"EmbedModelPath"`
|
||||||
|
EmbedTokenizerPath string `toml:"EmbedTokenizerPath"`
|
||||||
|
EmbedDims int `toml:"EmbedDims"`
|
||||||
// rag settings
|
// rag settings
|
||||||
RAGEnabled bool `toml:"RAGEnabled"`
|
RAGEnabled bool `toml:"RAGEnabled"`
|
||||||
RAGDir string `toml:"RAGDir"`
|
RAGDir string `toml:"RAGDir"`
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"gf-lt/models"
|
"gf-lt/models"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/sugarme/tokenizer"
|
"github.com/sugarme/tokenizer"
|
||||||
"github.com/sugarme/tokenizer/pretrained"
|
"github.com/sugarme/tokenizer/pretrained"
|
||||||
@@ -148,7 +149,17 @@ type ONNXEmbedder struct {
|
|||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var onnxInitOnce sync.Once
|
||||||
|
|
||||||
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
|
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
|
||||||
|
// Initialize ONNX runtime environment once
|
||||||
|
onnxInitOnce.Do(func() {
|
||||||
|
onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so")
|
||||||
|
err := onnxruntime_go.InitializeEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to initialize ONNX runtime", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
// Load tokenizer using sugarme/tokenizer
|
// Load tokenizer using sugarme/tokenizer
|
||||||
tok, err := pretrained.FromFile(tokenizerPath)
|
tok, err := pretrained.FromFile(tokenizerPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -195,7 +206,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
|
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
|
||||||
}
|
}
|
||||||
defer inputIDsTensor.Destroy()
|
defer func() { _ = inputIDsTensor.Destroy() }()
|
||||||
maskTensor, err := onnxruntime_go.NewTensor[int64](
|
maskTensor, err := onnxruntime_go.NewTensor[int64](
|
||||||
onnxruntime_go.NewShape(1, seqLen),
|
onnxruntime_go.NewShape(1, seqLen),
|
||||||
attentionMask,
|
attentionMask,
|
||||||
@@ -203,7 +214,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
|
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
|
||||||
}
|
}
|
||||||
defer maskTensor.Destroy()
|
defer func() { _ = maskTensor.Destroy() }()
|
||||||
// 4. Create output tensor
|
// 4. Create output tensor
|
||||||
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
|
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
|
||||||
onnxruntime_go.NewShape(1, int64(e.dims)),
|
onnxruntime_go.NewShape(1, int64(e.dims)),
|
||||||
@@ -211,7 +222,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create output tensor: %w", err)
|
return nil, fmt.Errorf("failed to create output tensor: %w", err)
|
||||||
}
|
}
|
||||||
defer outputTensor.Destroy()
|
defer func() { _ = outputTensor.Destroy() }()
|
||||||
// 5. Run inference
|
// 5. Run inference
|
||||||
err = e.session.Run(
|
err = e.session.Run(
|
||||||
[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
|
[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
|
||||||
@@ -257,16 +268,16 @@ func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
|||||||
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||||||
inputIDs,
|
inputIDs,
|
||||||
)
|
)
|
||||||
defer inputTensor.Destroy()
|
defer func() { _ = inputTensor.Destroy() }()
|
||||||
maskTensor, _ := onnxruntime_go.NewTensor[int64](
|
maskTensor, _ := onnxruntime_go.NewTensor[int64](
|
||||||
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||||||
attentionMask,
|
attentionMask,
|
||||||
)
|
)
|
||||||
defer maskTensor.Destroy()
|
defer func() { _ = maskTensor.Destroy() }()
|
||||||
outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32](
|
outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32](
|
||||||
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
|
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
|
||||||
)
|
)
|
||||||
defer outputTensor.Destroy()
|
defer func() { _ = outputTensor.Destroy() }()
|
||||||
err := e.session.Run(
|
err := e.session.Run(
|
||||||
[]onnxruntime_go.Value{inputTensor, maskTensor},
|
[]onnxruntime_go.Value{inputTensor, maskTensor},
|
||||||
[]onnxruntime_go.Value{outputTensor},
|
[]onnxruntime_go.Value{outputTensor},
|
||||||
|
|||||||
16
rag/rag.go
16
rag/rag.go
@@ -34,8 +34,20 @@ type RAG struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
||||||
// Initialize with API embedder by default, could be configurable later
|
var embedder Embedder
|
||||||
embedder := NewAPIEmbedder(l, cfg)
|
if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" {
|
||||||
|
emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
|
||||||
|
if err != nil {
|
||||||
|
l.Error("failed to create ONNX embedder, falling back to API", "error", err)
|
||||||
|
embedder = NewAPIEmbedder(l, cfg)
|
||||||
|
} else {
|
||||||
|
embedder = emb
|
||||||
|
l.Info("using ONNX embedder", "model", cfg.EmbedModelPath, "dims", cfg.EmbedDims)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
embedder = NewAPIEmbedder(l, cfg)
|
||||||
|
l.Info("using API embedder", "url", cfg.EmbedURL)
|
||||||
|
}
|
||||||
rag := &RAG{
|
rag := &RAG{
|
||||||
logger: l,
|
logger: l,
|
||||||
store: s,
|
store: s,
|
||||||
|
|||||||
Reference in New Issue
Block a user