Author: Grail Finder
Date:   2026-03-05 14:38:26 +03:00
parent 7c56e27dbe
commit 4bd6883966


@@ -174,134 +174,115 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log
 func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
 	// 1. Tokenize
-	encoding, err := e.tokenizer.Encode(text, true) // true = add special tokens
+	encoding, err := e.tokenizer.EncodeSingle(text)
 	if err != nil {
 		return nil, fmt.Errorf("tokenization failed: %w", err)
 	}
-	// Convert []int32 to []int64 for ONNX
-	inputIDs := make([]int64, len(encoding.GetIDs()))
-	for i, id := range encoding.GetIDs() {
+	// 2. Convert to int64 and create attention mask
+	ids := encoding.Ids
+	inputIDs := make([]int64, len(ids))
+	attentionMask := make([]int64, len(ids))
+	for i, id := range ids {
 		inputIDs[i] = int64(id)
+		attentionMask[i] = 1
 	}
-	attentionMask := make([]int64, len(encoding.GetAttentionMask()))
-	for i, m := range encoding.GetAttentionMask() {
-		attentionMask[i] = int64(m)
-	}
-	// 2. Create input tensors (shape: [1, seq_len])
+	// 3. Create input tensors (shape: [1, seq_len])
 	seqLen := int64(len(inputIDs))
-	inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs)
+	inputIDsTensor, err := onnxruntime_go.NewTensor[int64](
+		onnxruntime_go.NewShape(1, seqLen),
+		inputIDs,
+	)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
 	}
 	defer inputIDsTensor.Destroy()
-	maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask)
+	maskTensor, err := onnxruntime_go.NewTensor[int64](
+		onnxruntime_go.NewShape(1, seqLen),
+		attentionMask,
+	)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
 	}
 	defer maskTensor.Destroy()
-	// 3. Create output tensor (shape: [1, dims])
-	outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims)))
+	// 4. Create output tensor
+	outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
+		onnxruntime_go.NewShape(1, int64(e.dims)),
+	)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create output tensor: %w", err)
 	}
 	defer outputTensor.Destroy()
-	// 4. Run inference
+	// 5. Run inference
 	err = e.session.Run(
-		map[string]*onnxruntime_go.Tensor{
-			"input_ids":      inputIDsTensor,
-			"attention_mask": maskTensor,
-		},
+		[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
 		[]string{"sentence_embedding"},
-		[]*onnxruntime_go.Tensor{outputTensor},
+		[]onnxruntime_go.Value{outputTensor},
 	)
 	if err != nil {
 		return nil, fmt.Errorf("inference failed: %w", err)
 	}
-	// 5. Extract data
+	// 6. Copy output data
 	outputData := outputTensor.GetData()
-	// outputTensor is owned by us, but GetData returns a slice that remains valid until Destroy.
-	// We need to copy if we want to keep it after Destroy (we defer Destroy, so copy now).
 	embedding := make([]float32, len(outputData))
 	copy(embedding, outputData)
 	return embedding, nil
 }
-// EmbedSlice (batch) to be implemented properly
 func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
-	if len(texts) == 0 {
-		return nil, nil
-	}
-	// 1. Tokenize all texts and find max length for padding
 	encodings := make([]*tokenizer.Encoding, len(texts))
 	maxLen := 0
 	for i, txt := range texts {
-		enc, err := e.tokenizer.Encode(txt, true)
+		enc, err := e.tokenizer.EncodeSingle(txt)
 		if err != nil {
-			return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err)
+			return nil, err
 		}
 		encodings[i] = enc
-		if l := len(enc.GetIDs()); l > maxLen {
+		if l := len(enc.Ids); l > maxLen {
 			maxLen = l
 		}
 	}
-	// 2. Build padded input_ids and attention_mask (shape: [batch, maxLen])
 	batchSize := len(texts)
 	inputIDs := make([]int64, batchSize*maxLen)
 	attentionMask := make([]int64, batchSize*maxLen)
 	for i, enc := range encodings {
-		ids := enc.GetIDs()
-		mask := enc.GetAttentionMask()
+		ids := enc.Ids
 		offset := i * maxLen
-		// copy actual tokens
-		for j := 0; j < len(ids); j++ {
-			inputIDs[offset+j] = int64(ids[j])
-			attentionMask[offset+j] = int64(mask[j])
+		for j, id := range ids {
+			inputIDs[offset+j] = int64(id)
+			attentionMask[offset+j] = 1
 		}
-		// remaining positions (padding) are already zero-initialized
+		// Remaining positions are already zero (padding)
 	}
-	// 3. Create tensors
-	inputIDsTensor, err := onnxruntime_go.NewTensor(
+	// Create tensors with shape [batchSize, maxLen]
+	inputTensor, _ := onnxruntime_go.NewTensor[int64](
 		onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
 		inputIDs,
 	)
-	if err != nil {
-		return nil, err
-	}
-	defer inputIDsTensor.Destroy()
-	maskTensor, err := onnxruntime_go.NewTensor(
+	defer inputTensor.Destroy()
+	maskTensor, _ := onnxruntime_go.NewTensor[int64](
 		onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
 		attentionMask,
 	)
-	if err != nil {
-		return nil, err
-	}
 	defer maskTensor.Destroy()
-	outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
+	outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32](
 		onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
 	)
-	if err != nil {
-		return nil, err
-	}
 	defer outputTensor.Destroy()
-	// 4. Run
-	err = e.session.Run(
-		map[string]*onnxruntime_go.Tensor{
-			"input_ids":      inputIDsTensor,
-			"attention_mask": maskTensor,
-		},
+	err := e.session.Run(
+		[]onnxruntime_go.Value{inputTensor, maskTensor},
 		[]string{"sentence_embedding"},
-		[]*onnxruntime_go.Tensor{outputTensor},
+		[]onnxruntime_go.Value{outputTensor},
 	)
 	if err != nil {
 		return nil, err
 	}
-	// 5. Extract batch results
-	outputData := outputTensor.GetData()
+	// Extract embeddings per batch item
+	data := outputTensor.GetData()
 	embeddings := make([][]float32, batchSize)
 	for i := 0; i < batchSize; i++ {
 		start := i * e.dims
 		emb := make([]float32, e.dims)
-		copy(emb, outputData[start:start+e.dims])
+		copy(emb, data[start:start+e.dims])
 		embeddings[i] = emb
 	}
 	return embeddings, nil
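For orientation, a minimal usage sketch of the embedder after this change. It assumes the constructor shown in the hunk header (`NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger)`) returns `(*ONNXEmbedder, error)`; the package name, file paths, and the 384 dimension are illustrative placeholders, not values taken from this commit.

```go
// Placeholder package name; in practice this sketch would live in the same
// package that defines ONNXEmbedder and NewONNXEmbedder.
package embedderexample

import (
	"fmt"
	"log/slog"
	"os"
)

// exampleEmbedderUsage shows how the updated Embed and EmbedSlice might be
// called. Paths and the 384 dims are assumptions, as is the constructor
// returning (*ONNXEmbedder, error).
func exampleEmbedderUsage() error {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

	emb, err := NewONNXEmbedder("model.onnx", "tokenizer.json", 384, logger)
	if err != nil {
		return fmt.Errorf("create embedder: %w", err)
	}

	// Single text: Embed returns one vector of length dims (384 here).
	vec, err := emb.Embed("hello world")
	if err != nil {
		return fmt.Errorf("embed: %w", err)
	}
	fmt.Println("single embedding length:", len(vec))

	// Batch: EmbedSlice pads every input to the longest tokenized sequence
	// and returns one vector per input text.
	vecs, err := emb.EmbedSlice([]string{"first text", "second text"})
	if err != nil {
		return fmt.Errorf("batch embed: %w", err)
	}
	fmt.Println("batch size:", len(vecs))
	return nil
}
```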