Dep: trying sugarme tokenizer
rag/embedder.go
@@ -10,8 +10,8 @@ import (
 	"log/slog"
 	"net/http"
 
-	"github.com/takara-ai/go-tokenizers/tokenizers"
-
+	"github.com/sugarme/tokenizer"
+	"github.com/sugarme/tokenizer/pretrained"
 	"github.com/yalue/onnxruntime_go"
 )
 
@@ -141,59 +141,168 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
 // 1. Loading ONNX models locally
 // 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
 // 3. Converting text to embeddings without external API calls
 
 type ONNXEmbedder struct {
 	session   *onnxruntime_go.DynamicAdvancedSession
-	tokenizer *tokenizers.Tokenizer
-	dims      int // 768, 512, 256, or 128 for Matryoshka
+	tokenizer *tokenizer.Tokenizer
+	dims      int // embedding dimension (e.g., 768)
 	logger    *slog.Logger
 }
 
-func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
-	// Batch processing
-	inputs := e.prepareBatch(texts)
-	outputs := make([][]float32, len(texts))
-
-	// Run batch inference (much faster)
-	err := e.session.Run(inputs, outputs)
-	return outputs, err
-}
-
-func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) {
-	// Load ONNX model
+func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
+	// Load tokenizer using sugarme/tokenizer
+	tok, err := pretrained.FromFile(tokenizerPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load tokenizer: %w", err)
+	}
 	// Create ONNX session
 	session, err := onnxruntime_go.NewDynamicAdvancedSession(
 		modelPath, // onnx/embedgemma/model_q4.onnx
 		[]string{"input_ids", "attention_mask"},
 		[]string{"sentence_embedding"},
-		nil,
+		nil, // optional options
 	)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to create ONNX session: %w", err)
 	}
-	// Load tokenizer (from Hugging Face)
-	tokenizer, err := tokenizers.FromFile("./tokenizer.json")
 	return &ONNXEmbedder{
 		session:   session,
-		tokenizer: tokenizer,
+		tokenizer: tok,
+		dims:      dims,
+		logger:    logger,
 	}, nil
 }
 
 func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
-	// Tokenize
-	tokens := e.tokenizer.Encode(text, true)
-	// Prepare inputs
-	inputIDs := []int64{tokens.GetIds()}
-	attentionMask := []int64{tokens.GetAttentionMask()}
-	// Run inference
-	output := onnxruntime_go.NewEmptyTensor[float32](
-		onnxruntime_go.NewShape(1, 768),
-	)
-	err := e.session.Run(
-		map[string]any{
-			"input_ids":      inputIDs,
-			"attention_mask": attentionMask,
+	// 1. Tokenize
+	encoding, err := e.tokenizer.Encode(text, true) // true = add special tokens
+	if err != nil {
+		return nil, fmt.Errorf("tokenization failed: %w", err)
+	}
+	// Convert []int32 to []int64 for ONNX
+	inputIDs := make([]int64, len(encoding.GetIDs()))
+	for i, id := range encoding.GetIDs() {
+		inputIDs[i] = int64(id)
+	}
+	attentionMask := make([]int64, len(encoding.GetAttentionMask()))
+	for i, m := range encoding.GetAttentionMask() {
+		attentionMask[i] = int64(m)
+	}
+	// 2. Create input tensors (shape: [1, seq_len])
+	seqLen := int64(len(inputIDs))
+	inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
+	}
+	defer inputIDsTensor.Destroy()
+	maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
+	}
+	defer maskTensor.Destroy()
+	// 3. Create output tensor (shape: [1, dims])
+	outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims)))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create output tensor: %w", err)
+	}
+	defer outputTensor.Destroy()
+	// 4. Run inference
+	err = e.session.Run(
+		map[string]*onnxruntime_go.Tensor{
+			"input_ids":      inputIDsTensor,
+			"attention_mask": maskTensor,
 		},
 		[]string{"sentence_embedding"},
-		[]any{&output},
+		[]*onnxruntime_go.Tensor{outputTensor},
 	)
-	return output.GetData(), nil
+	if err != nil {
+		return nil, fmt.Errorf("inference failed: %w", err)
+	}
+	// 5. Extract data
+	outputData := outputTensor.GetData()
+	// outputTensor is owned by us, but GetData returns a slice that remains valid until Destroy.
+	// We need to copy if we want to keep it after Destroy (we defer Destroy, so copy now).
+	embedding := make([]float32, len(outputData))
+	copy(embedding, outputData)
+	return embedding, nil
 }
 
+// EmbedSlice (batch) – to be implemented properly
+func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
+	if len(texts) == 0 {
+		return nil, nil
+	}
+	// 1. Tokenize all texts and find max length for padding
+	encodings := make([]*tokenizer.Encoding, len(texts))
+	maxLen := 0
+	for i, txt := range texts {
+		enc, err := e.tokenizer.Encode(txt, true)
+		if err != nil {
+			return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err)
+		}
+		encodings[i] = enc
+		if l := len(enc.GetIDs()); l > maxLen {
+			maxLen = l
+		}
+	}
+	// 2. Build padded input_ids and attention_mask (shape: [batch, maxLen])
+	batchSize := len(texts)
+	inputIDs := make([]int64, batchSize*maxLen)
+	attentionMask := make([]int64, batchSize*maxLen)
+	for i, enc := range encodings {
+		ids := enc.GetIDs()
+		mask := enc.GetAttentionMask()
+		offset := i * maxLen
+		// copy actual tokens
+		for j := 0; j < len(ids); j++ {
+			inputIDs[offset+j] = int64(ids[j])
+			attentionMask[offset+j] = int64(mask[j])
+		}
+		// remaining positions (padding) are already zero-initialized
+	}
+	// 3. Create tensors
+	inputIDsTensor, err := onnxruntime_go.NewTensor(
+		onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
+		inputIDs,
+	)
+	if err != nil {
+		return nil, err
+	}
+	defer inputIDsTensor.Destroy()
+	maskTensor, err := onnxruntime_go.NewTensor(
+		onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
+		attentionMask,
+	)
+	if err != nil {
+		return nil, err
+	}
+	defer maskTensor.Destroy()
+	outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
+		onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
+	)
+	if err != nil {
+		return nil, err
+	}
+	defer outputTensor.Destroy()
+	// 4. Run
+	err = e.session.Run(
+		map[string]*onnxruntime_go.Tensor{
+			"input_ids":      inputIDsTensor,
+			"attention_mask": maskTensor,
+		},
+		[]string{"sentence_embedding"},
+		[]*onnxruntime_go.Tensor{outputTensor},
+	)
+	if err != nil {
+		return nil, err
+	}
+	// 5. Extract batch results
+	outputData := outputTensor.GetData()
+	embeddings := make([][]float32, batchSize)
+	for i := 0; i < batchSize; i++ {
+		start := i * e.dims
+		emb := make([]float32, e.dims)
+		copy(emb, outputData[start:start+e.dims])
+		embeddings[i] = emb
+	}
+	return embeddings, nil
+}
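For context, a minimal usage sketch (not part of this commit) of how the new constructor and methods might be wired together, assuming the file's package is rag. The model path, tokenizer path, the 768-dimension value, and the exampleEmbed helper are illustrative assumptions, not code or values taken from this repository.

package rag

import (
	"fmt"
	"log/slog"
	"os"
)

// exampleEmbed is a hypothetical helper sketching the intended call flow:
// construct the ONNX-backed embedder, embed a single string, then a batch.
func exampleEmbed() error {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

	// Paths and dims are assumptions; point them at the exported ONNX model,
	// its tokenizer.json, and the embedding width the model actually produces.
	emb, err := NewONNXEmbedder("onnx/embedgemma/model_q4.onnx", "onnx/embedgemma/tokenizer.json", 768, logger)
	if err != nil {
		return fmt.Errorf("create embedder: %w", err)
	}

	// Single text -> one vector of length dims.
	vec, err := emb.Embed("what does the embedder return?")
	if err != nil {
		return err
	}
	fmt.Println("single embedding dims:", len(vec))

	// Batch of texts -> one vector per input; shorter inputs are zero-padded
	// to the longest tokenized sequence before inference.
	vecs, err := emb.EmbedSlice([]string{"first chunk", "second chunk"})
	if err != nil {
		return err
	}
	fmt.Println("batch embeddings:", len(vecs))
	return nil
}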