Enha: local onnx

This commit is contained in:
Grail Finder
2026-03-05 14:13:58 +03:00
parent c65c11bcfb
commit fbc955ca37
5 changed files with 172 additions and 7 deletions

View File

@@ -9,6 +9,10 @@ import (
"gf-lt/models"
"log/slog"
"net/http"
"github.com/takara-ai/go-tokenizers/tokenizers"
"github.com/yalue/onnxruntime_go"
)
// Embedder defines the interface for embedding text
@@ -134,11 +138,62 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
return embeddings, nil
}
// TODO: ONNXEmbedder implementation would go here
// This would require:
// 1. Loading ONNX models locally
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
//
// For now, we'll focus on the API implementation which is already working in the current system,
// and can be extended later when we have ONNX runtime integration
type ONNXEmbedder struct {
session *onnxruntime_go.DynamicAdvancedSession
tokenizer *tokenizers.Tokenizer
dims int // 768, 512, 256, or 128 for Matryoshka
}
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
// Batch processing
inputs := e.prepareBatch(texts)
outputs := make([][]float32, len(texts))
// Run batch inference (much faster)
err := e.session.Run(inputs, outputs)
return outputs, err
}
func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) {
// Load ONNX model
session, err := onnxruntime_go.NewDynamicAdvancedSession(
modelPath, // onnx/embedgemma/model_q4.onnx
[]string{"input_ids", "attention_mask"},
[]string{"sentence_embedding"},
nil,
)
if err != nil {
return nil, err
}
// Load tokenizer (from Hugging Face)
tokenizer, err := tokenizers.FromFile("./tokenizer.json")
return &ONNXEmbedder{
session: session,
tokenizer: tokenizer,
}, nil
}
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
// Tokenize
tokens := e.tokenizer.Encode(text, true)
// Prepare inputs
inputIDs := []int64{tokens.GetIds()}
attentionMask := []int64{tokens.GetAttentionMask()}
// Run inference
output := onnxruntime_go.NewEmptyTensor[float32](
onnxruntime_go.NewShape(1, 768),
)
err := e.session.Run(
map[string]any{
"input_ids": inputIDs,
"attention_mask": attentionMask,
},
[]string{"sentence_embedding"},
[]any{&output},
)
return output.GetData(), nil
}

View File

@@ -246,7 +246,7 @@ func (r *RAG) extractImportantPhrases(query string) string {
break
}
}
if isImportant || len(word) > 3 {
if isImportant || len(word) >= 3 {
important = append(important, word)
}
}