WIP
This commit is contained in:
105
rag/embedder.go
105
rag/embedder.go
@@ -174,134 +174,115 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log
|
|||||||
|
|
||||||
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||||
// 1. Tokenize
|
// 1. Tokenize
|
||||||
encoding, err := e.tokenizer.Encode(text, true) // true = add special tokens
|
encoding, err := e.tokenizer.EncodeSingle(text)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("tokenization failed: %w", err)
|
return nil, fmt.Errorf("tokenization failed: %w", err)
|
||||||
}
|
}
|
||||||
// Convert []int32 to []int64 for ONNX
|
// 2. Convert to int64 and create attention mask
|
||||||
inputIDs := make([]int64, len(encoding.GetIDs()))
|
ids := encoding.Ids
|
||||||
for i, id := range encoding.GetIDs() {
|
inputIDs := make([]int64, len(ids))
|
||||||
|
attentionMask := make([]int64, len(ids))
|
||||||
|
for i, id := range ids {
|
||||||
inputIDs[i] = int64(id)
|
inputIDs[i] = int64(id)
|
||||||
|
attentionMask[i] = 1
|
||||||
}
|
}
|
||||||
attentionMask := make([]int64, len(encoding.GetAttentionMask()))
|
// 3. Create input tensors (shape: [1, seq_len])
|
||||||
for i, m := range encoding.GetAttentionMask() {
|
|
||||||
attentionMask[i] = int64(m)
|
|
||||||
}
|
|
||||||
// 2. Create input tensors (shape: [1, seq_len])
|
|
||||||
seqLen := int64(len(inputIDs))
|
seqLen := int64(len(inputIDs))
|
||||||
inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs)
|
inputIDsTensor, err := onnxruntime_go.NewTensor[int64](
|
||||||
|
onnxruntime_go.NewShape(1, seqLen),
|
||||||
|
inputIDs,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
|
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
|
||||||
}
|
}
|
||||||
defer inputIDsTensor.Destroy()
|
defer inputIDsTensor.Destroy()
|
||||||
maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask)
|
maskTensor, err := onnxruntime_go.NewTensor[int64](
|
||||||
|
onnxruntime_go.NewShape(1, seqLen),
|
||||||
|
attentionMask,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
|
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
|
||||||
}
|
}
|
||||||
defer maskTensor.Destroy()
|
defer maskTensor.Destroy()
|
||||||
// 3. Create output tensor (shape: [1, dims])
|
// 4. Create output tensor
|
||||||
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims)))
|
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
|
||||||
|
onnxruntime_go.NewShape(1, int64(e.dims)),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create output tensor: %w", err)
|
return nil, fmt.Errorf("failed to create output tensor: %w", err)
|
||||||
}
|
}
|
||||||
defer outputTensor.Destroy()
|
defer outputTensor.Destroy()
|
||||||
// 4. Run inference
|
// 5. Run inference
|
||||||
err = e.session.Run(
|
err = e.session.Run(
|
||||||
map[string]*onnxruntime_go.Tensor{
|
[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
|
||||||
"input_ids": inputIDsTensor,
|
|
||||||
"attention_mask": maskTensor,
|
|
||||||
},
|
|
||||||
[]string{"sentence_embedding"},
|
[]string{"sentence_embedding"},
|
||||||
[]*onnxruntime_go.Tensor{outputTensor},
|
[]onnxruntime_go.Value{outputTensor},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("inference failed: %w", err)
|
return nil, fmt.Errorf("inference failed: %w", err)
|
||||||
}
|
}
|
||||||
// 5. Extract data
|
// 6. Copy output data
|
||||||
outputData := outputTensor.GetData()
|
outputData := outputTensor.GetData()
|
||||||
// outputTensor is owned by us, but GetData returns a slice that remains valid until Destroy.
|
|
||||||
// We need to copy if we want to keep it after Destroy (we defer Destroy, so copy now).
|
|
||||||
embedding := make([]float32, len(outputData))
|
embedding := make([]float32, len(outputData))
|
||||||
copy(embedding, outputData)
|
copy(embedding, outputData)
|
||||||
return embedding, nil
|
return embedding, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbedSlice (batch) – to be implemented properly
|
|
||||||
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
||||||
if len(texts) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
// 1. Tokenize all texts and find max length for padding
|
|
||||||
encodings := make([]*tokenizer.Encoding, len(texts))
|
encodings := make([]*tokenizer.Encoding, len(texts))
|
||||||
maxLen := 0
|
maxLen := 0
|
||||||
for i, txt := range texts {
|
for i, txt := range texts {
|
||||||
enc, err := e.tokenizer.Encode(txt, true)
|
enc, err := e.tokenizer.EncodeSingle(txt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err)
|
return nil, err
|
||||||
}
|
}
|
||||||
encodings[i] = enc
|
encodings[i] = enc
|
||||||
if l := len(enc.GetIDs()); l > maxLen {
|
if l := len(enc.Ids); l > maxLen {
|
||||||
maxLen = l
|
maxLen = l
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 2. Build padded input_ids and attention_mask (shape: [batch, maxLen])
|
|
||||||
batchSize := len(texts)
|
batchSize := len(texts)
|
||||||
inputIDs := make([]int64, batchSize*maxLen)
|
inputIDs := make([]int64, batchSize*maxLen)
|
||||||
attentionMask := make([]int64, batchSize*maxLen)
|
attentionMask := make([]int64, batchSize*maxLen)
|
||||||
for i, enc := range encodings {
|
for i, enc := range encodings {
|
||||||
ids := enc.GetIDs()
|
ids := enc.Ids
|
||||||
mask := enc.GetAttentionMask()
|
|
||||||
offset := i * maxLen
|
offset := i * maxLen
|
||||||
// copy actual tokens
|
for j, id := range ids {
|
||||||
for j := 0; j < len(ids); j++ {
|
inputIDs[offset+j] = int64(id)
|
||||||
inputIDs[offset+j] = int64(ids[j])
|
attentionMask[offset+j] = 1
|
||||||
attentionMask[offset+j] = int64(mask[j])
|
|
||||||
}
|
}
|
||||||
// remaining positions (padding) are already zero-initialized
|
// Remaining positions are already zero (padding)
|
||||||
}
|
}
|
||||||
// 3. Create tensors
|
// Create tensors with shape [batchSize, maxLen]
|
||||||
inputIDsTensor, err := onnxruntime_go.NewTensor(
|
inputTensor, _ := onnxruntime_go.NewTensor[int64](
|
||||||
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||||||
inputIDs,
|
inputIDs,
|
||||||
)
|
)
|
||||||
if err != nil {
|
defer inputTensor.Destroy()
|
||||||
return nil, err
|
maskTensor, _ := onnxruntime_go.NewTensor[int64](
|
||||||
}
|
|
||||||
defer inputIDsTensor.Destroy()
|
|
||||||
maskTensor, err := onnxruntime_go.NewTensor(
|
|
||||||
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||||||
attentionMask,
|
attentionMask,
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer maskTensor.Destroy()
|
defer maskTensor.Destroy()
|
||||||
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
|
outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32](
|
||||||
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
|
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer outputTensor.Destroy()
|
defer outputTensor.Destroy()
|
||||||
// 4. Run
|
err := e.session.Run(
|
||||||
err = e.session.Run(
|
[]onnxruntime_go.Value{inputTensor, maskTensor},
|
||||||
map[string]*onnxruntime_go.Tensor{
|
|
||||||
"input_ids": inputIDsTensor,
|
|
||||||
"attention_mask": maskTensor,
|
|
||||||
},
|
|
||||||
[]string{"sentence_embedding"},
|
[]string{"sentence_embedding"},
|
||||||
[]*onnxruntime_go.Tensor{outputTensor},
|
[]onnxruntime_go.Value{outputTensor},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// 5. Extract batch results
|
// Extract embeddings per batch item
|
||||||
outputData := outputTensor.GetData()
|
data := outputTensor.GetData()
|
||||||
embeddings := make([][]float32, batchSize)
|
embeddings := make([][]float32, batchSize)
|
||||||
for i := 0; i < batchSize; i++ {
|
for i := 0; i < batchSize; i++ {
|
||||||
start := i * e.dims
|
start := i * e.dims
|
||||||
emb := make([]float32, e.dims)
|
emb := make([]float32, e.dims)
|
||||||
copy(emb, outputData[start:start+e.dims])
|
copy(emb, data[start:start+e.dims])
|
||||||
embeddings[i] = emb
|
embeddings[i] = emb
|
||||||
}
|
}
|
||||||
return embeddings, nil
|
return embeddings, nil
|
||||||
|
|||||||
Reference in New Issue
Block a user