diff --git a/rag/embedder.go b/rag/embedder.go index 396f04b..988d91e 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -174,134 +174,115 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { // 1. Tokenize - encoding, err := e.tokenizer.Encode(text, true) // true = add special tokens + encoding, err := e.tokenizer.EncodeSingle(text) if err != nil { return nil, fmt.Errorf("tokenization failed: %w", err) } - // Convert []int32 to []int64 for ONNX - inputIDs := make([]int64, len(encoding.GetIDs())) - for i, id := range encoding.GetIDs() { + // 2. Convert to int64 and create attention mask + ids := encoding.Ids + inputIDs := make([]int64, len(ids)) + attentionMask := make([]int64, len(ids)) + for i, id := range ids { inputIDs[i] = int64(id) + attentionMask[i] = 1 } - attentionMask := make([]int64, len(encoding.GetAttentionMask())) - for i, m := range encoding.GetAttentionMask() { - attentionMask[i] = int64(m) - } - // 2. Create input tensors (shape: [1, seq_len]) + // 3. Create input tensors (shape: [1, seq_len]) seqLen := int64(len(inputIDs)) - inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs) + inputIDsTensor, err := onnxruntime_go.NewTensor[int64]( + onnxruntime_go.NewShape(1, seqLen), + inputIDs, + ) if err != nil { return nil, fmt.Errorf("failed to create input_ids tensor: %w", err) } defer inputIDsTensor.Destroy() - maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask) + maskTensor, err := onnxruntime_go.NewTensor[int64]( + onnxruntime_go.NewShape(1, seqLen), + attentionMask, + ) if err != nil { return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err) } defer maskTensor.Destroy() - // 3. Create output tensor (shape: [1, dims]) - outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims))) + // 4. Create output tensor + outputTensor, err := onnxruntime_go.NewEmptyTensor[float32]( + onnxruntime_go.NewShape(1, int64(e.dims)), + ) if err != nil { return nil, fmt.Errorf("failed to create output tensor: %w", err) } defer outputTensor.Destroy() - // 4. Run inference + // 5. Run inference err = e.session.Run( - map[string]*onnxruntime_go.Tensor{ - "input_ids": inputIDsTensor, - "attention_mask": maskTensor, - }, + []onnxruntime_go.Value{inputIDsTensor, maskTensor}, []string{"sentence_embedding"}, - []*onnxruntime_go.Tensor{outputTensor}, + []onnxruntime_go.Value{outputTensor}, ) if err != nil { return nil, fmt.Errorf("inference failed: %w", err) } - // 5. Extract data + // 6. Copy output data outputData := outputTensor.GetData() - // outputTensor is owned by us, but GetData returns a slice that remains valid until Destroy. - // We need to copy if we want to keep it after Destroy (we defer Destroy, so copy now). embedding := make([]float32, len(outputData)) copy(embedding, outputData) return embedding, nil } -// EmbedSlice (batch) – to be implemented properly func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) { - if len(texts) == 0 { - return nil, nil - } - // 1. Tokenize all texts and find max length for padding encodings := make([]*tokenizer.Encoding, len(texts)) maxLen := 0 for i, txt := range texts { - enc, err := e.tokenizer.Encode(txt, true) + enc, err := e.tokenizer.EncodeSingle(txt) if err != nil { - return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err) + return nil, err } encodings[i] = enc - if l := len(enc.GetIDs()); l > maxLen { + if l := len(enc.Ids); l > maxLen { maxLen = l } } - // 2. Build padded input_ids and attention_mask (shape: [batch, maxLen]) batchSize := len(texts) inputIDs := make([]int64, batchSize*maxLen) attentionMask := make([]int64, batchSize*maxLen) for i, enc := range encodings { - ids := enc.GetIDs() - mask := enc.GetAttentionMask() + ids := enc.Ids offset := i * maxLen - // copy actual tokens - for j := 0; j < len(ids); j++ { - inputIDs[offset+j] = int64(ids[j]) - attentionMask[offset+j] = int64(mask[j]) + for j, id := range ids { + inputIDs[offset+j] = int64(id) + attentionMask[offset+j] = 1 } - // remaining positions (padding) are already zero-initialized + // Remaining positions are already zero (padding) } - // 3. Create tensors - inputIDsTensor, err := onnxruntime_go.NewTensor( + // Create tensors with shape [batchSize, maxLen] + inputTensor, _ := onnxruntime_go.NewTensor[int64]( onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), inputIDs, ) - if err != nil { - return nil, err - } - defer inputIDsTensor.Destroy() - maskTensor, err := onnxruntime_go.NewTensor( + defer inputTensor.Destroy() + maskTensor, _ := onnxruntime_go.NewTensor[int64]( onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), attentionMask, ) - if err != nil { - return nil, err - } defer maskTensor.Destroy() - outputTensor, err := onnxruntime_go.NewEmptyTensor[float32]( + outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32]( onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)), ) - if err != nil { - return nil, err - } defer outputTensor.Destroy() - // 4. Run - err = e.session.Run( - map[string]*onnxruntime_go.Tensor{ - "input_ids": inputIDsTensor, - "attention_mask": maskTensor, - }, + err := e.session.Run( + []onnxruntime_go.Value{inputTensor, maskTensor}, []string{"sentence_embedding"}, - []*onnxruntime_go.Tensor{outputTensor}, + []onnxruntime_go.Value{outputTensor}, ) if err != nil { return nil, err } - // 5. Extract batch results - outputData := outputTensor.GetData() + // Extract embeddings per batch item + data := outputTensor.GetData() embeddings := make([][]float32, batchSize) for i := 0; i < batchSize; i++ { start := i * e.dims emb := make([]float32, e.dims) - copy(emb, outputData[start:start+e.dims]) + copy(emb, data[start:start+e.dims]) embeddings[i] = emb } return embeddings, nil