Enhance: local ONNX embedder
This commit is contained in:
@@ -9,6 +9,10 @@ import (
|
||||
"gf-lt/models"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/takara-ai/go-tokenizers/tokenizers"
|
||||
|
||||
"github.com/yalue/onnxruntime_go"
|
||||
)
|
||||
|
||||
// Embedder defines the interface for embedding text
|
||||
@@ -134,11 +138,62 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// ONNXEmbedder: local embedding support, implemented below. It covers:
// 1. Loading ONNX models locally
// 2. Using a Go ONNX runtime (github.com/yalue/onnxruntime_go)
// 3. Converting text to embeddings without external API calls
//
// The API-based implementation above remains the default working path in the
// current system; this local ONNX path complements it.
||||
type ONNXEmbedder struct {
|
||||
session *onnxruntime_go.DynamicAdvancedSession
|
||||
tokenizer *tokenizers.Tokenizer
|
||||
dims int // 768, 512, 256, or 128 for Matryoshka
|
||||
}
|
||||
|
||||
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
||||
// Batch processing
|
||||
inputs := e.prepareBatch(texts)
|
||||
outputs := make([][]float32, len(texts))
|
||||
|
||||
// Run batch inference (much faster)
|
||||
err := e.session.Run(inputs, outputs)
|
||||
return outputs, err
|
||||
}
|
||||
|
||||
func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) {
|
||||
// Load ONNX model
|
||||
session, err := onnxruntime_go.NewDynamicAdvancedSession(
|
||||
modelPath, // onnx/embedgemma/model_q4.onnx
|
||||
[]string{"input_ids", "attention_mask"},
|
||||
[]string{"sentence_embedding"},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Load tokenizer (from Hugging Face)
|
||||
tokenizer, err := tokenizers.FromFile("./tokenizer.json")
|
||||
return &ONNXEmbedder{
|
||||
session: session,
|
||||
tokenizer: tokenizer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||
// Tokenize
|
||||
tokens := e.tokenizer.Encode(text, true)
|
||||
// Prepare inputs
|
||||
inputIDs := []int64{tokens.GetIds()}
|
||||
attentionMask := []int64{tokens.GetAttentionMask()}
|
||||
// Run inference
|
||||
output := onnxruntime_go.NewEmptyTensor[float32](
|
||||
onnxruntime_go.NewShape(1, 768),
|
||||
)
|
||||
err := e.session.Run(
|
||||
map[string]any{
|
||||
"input_ids": inputIDs,
|
||||
"attention_mask": attentionMask,
|
||||
},
|
||||
[]string{"sentence_embedding"},
|
||||
[]any{&output},
|
||||
)
|
||||
return output.GetData(), nil
|
||||
}
|
||||
|
||||
@@ -246,7 +246,7 @@ func (r *RAG) extractImportantPhrases(query string) string {
|
||||
break
|
||||
}
|
||||
}
|
||||
if isImportant || len(word) > 3 {
|
||||
if isImportant || len(word) >= 3 {
|
||||
important = append(important, word)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user