Files
gf-lt/rag/embedder.go
2026-03-05 14:27:19 +03:00

309 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package rag
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"gf-lt/config"
	"gf-lt/models"
	"io"
	"log/slog"
	"net/http"
	"time"

	"github.com/sugarme/tokenizer"
	"github.com/sugarme/tokenizer/pretrained"
	"github.com/yalue/onnxruntime_go"
)
// Embedder turns text into dense float32 embedding vectors.
//
// Implementations offer a batched form (EmbedSlice) and a single-text
// convenience form (Embed).
type Embedder interface {
	// EmbedSlice returns one embedding per element of lines.
	EmbedSlice(lines []string) ([][]float32, error)
	// Embed returns the embedding of a single text.
	Embed(text string) ([]float32, error)
}
// APIEmbedder implements Embedder by calling a remote embeddings API (like
// Hugging Face, OpenAI, etc.). Requests are OpenAI-style JSON bodies of the
// form {"input": ..., "encoding_format": "float"} POSTed to cfg.EmbedURL.
type APIEmbedder struct {
	logger *slog.Logger
	client *http.Client
	cfg    *config.Config // supplies EmbedURL and the optional HFToken bearer token
}

// NewAPIEmbedder returns an APIEmbedder that posts to cfg.EmbedURL.
//
// The HTTP client carries an overall per-request timeout so that a stalled
// embedding server cannot block the caller indefinitely.
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
	// Deliberately generous: batched EmbedSlice calls against a slow (e.g.
	// CPU-only) server can legitimately take a while, but must not hang forever.
	const requestTimeout = 5 * time.Minute
	return &APIEmbedder{
		logger: l,
		client: &http.Client{Timeout: requestTimeout},
		cfg:    cfg,
	}
}

// Embed returns the embedding of a single text. It fails if the request
// fails, the server answers with a non-200 status, the body cannot be
// decoded, or the response carries no (or an empty) first embedding.
func (a *APIEmbedder) Embed(text string) ([]float32, error) {
	embResp, err := a.postEmbeddings(text)
	if err != nil {
		return nil, err
	}
	if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
		a.logger.Error("empty embedding response")
		return nil, errors.New("empty embedding response")
	}
	return embResp.Data[0].Embedding, nil
}

// EmbedSlice embeds all lines in one request and returns the vectors in the
// same order as lines.
//
// Each response item is placed by its "index" field (servers are not
// required to return items in input order). The response is rejected unless
// it contains exactly len(lines) non-empty embeddings whose indices are
// distinct and in range, so a short or malformed reply can never be
// silently misaligned with the input. An empty lines slice returns
// (nil, nil) without contacting the server.
func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
	if len(lines) == 0 {
		return nil, nil
	}
	embResp, err := a.postEmbeddings(lines)
	if err != nil {
		return nil, err
	}
	if len(embResp.Data) == 0 {
		a.logger.Error("empty embedding response")
		return nil, errors.New("empty embedding response")
	}
	if len(embResp.Data) != len(lines) {
		err = fmt.Errorf("embedding count mismatch: sent %d inputs, got %d embeddings",
			len(lines), len(embResp.Data))
		a.logger.Error(err.Error())
		return nil, err
	}
	embeddings := make([][]float32, len(lines))
	for _, data := range embResp.Data {
		if data.Index < 0 || data.Index >= len(embeddings) {
			a.logger.Error("invalid embedding index", "index", data.Index)
			return nil, fmt.Errorf("invalid embedding index %d", data.Index)
		}
		if len(data.Embedding) == 0 {
			a.logger.Error("empty embedding", "index", data.Index)
			return nil, fmt.Errorf("empty embedding at index %d", data.Index)
		}
		// Empty vectors were rejected above, so a non-nil slot means this
		// index was already filled by an earlier item.
		if embeddings[data.Index] != nil {
			a.logger.Error("duplicate embedding index", "index", data.Index)
			return nil, fmt.Errorf("duplicate embedding index %d", data.Index)
		}
		embeddings[data.Index] = data.Embedding
	}
	// Count matches and indices are distinct and in range, so every slot is filled.
	return embeddings, nil
}

// postEmbeddings sends input (a string or []string) to cfg.EmbedURL and
// decodes the JSON reply into an LCPEmbedResp. Every failure is logged and
// returned; a non-200 status includes a bounded snippet of the response body.
func (a *APIEmbedder) postEmbeddings(input any) (*models.LCPEmbedResp, error) {
	payload, err := json.Marshal(
		map[string]any{"input": input, "encoding_format": "float"},
	)
	if err != nil {
		a.logger.Error("failed to marshal payload", "err", err.Error())
		return nil, err
	}
	req, err := http.NewRequest(http.MethodPost, a.cfg.EmbedURL, bytes.NewReader(payload))
	if err != nil {
		a.logger.Error("failed to create new req", "err", err.Error())
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	if a.cfg.HFToken != "" {
		req.Header.Set("Authorization", "Bearer "+a.cfg.HFToken)
	}
	resp, err := a.client.Do(req)
	if err != nil {
		a.logger.Error("failed to embed text", "err", err.Error())
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Servers typically explain failures in the body; keep a bounded
		// snippet for diagnosis. A read error here only loses that extra
		// context, so it is deliberately ignored.
		const maxErrBody = 512
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		msg := fmt.Sprintf("non 200 response; code: %v", resp.StatusCode)
		if snippet := bytes.TrimSpace(body); len(snippet) > 0 {
			msg += "; body: " + string(snippet)
		}
		err = errors.New(msg)
		a.logger.Error(err.Error())
		return nil, err
	}
	embResp := &models.LCPEmbedResp{}
	if err := json.NewDecoder(resp.Body).Decode(embResp); err != nil {
		a.logger.Error("failed to decode embedding response", "err", err.Error())
		return nil, err
	}
	return embResp, nil
}
// ONNXEmbedder implements Embedder entirely in-process, with no external API
// calls: text is tokenized with a sugarme/tokenizer Tokenizer and the token
// IDs are run through an ONNX model via yalue/onnxruntime_go.
//
// NewONNXEmbedder creates the session with inputs "input_ids" and
// "attention_mask" and the single output "sentence_embedding"; dims must
// match that output's width because Embed and EmbedSlice pre-allocate their
// output tensors from it.
type ONNXEmbedder struct {
// session runs the embedding model.
// NOTE(review): nothing visible in this file destroys the session — confirm
// who owns its lifetime.
session *onnxruntime_go.DynamicAdvancedSession
// tokenizer converts text into token IDs and attention masks.
tokenizer *tokenizer.Tokenizer
dims int // embedding dimension (e.g., 768); width of each returned vector
// logger is stored but not used by the methods visible in this file.
logger *slog.Logger
}
// NewONNXEmbedder loads the tokenizer at tokenizerPath and creates an ONNX
// Runtime session for the model at modelPath, bound to the inputs
// "input_ids" and "attention_mask" and the output "sentence_embedding".
//
// dims must be positive and equal the width of the model's
// "sentence_embedding" output; Embed and EmbedSlice size their output
// tensors from it.
//
// NOTE(review): assumes the onnxruntime shared library/environment has
// already been initialised elsewhere (onnxruntime_go.InitializeEnvironment)
// — confirm at the call site.
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
	// Validate up front: a non-positive width would otherwise only surface
	// later as an opaque tensor-shape error on every Embed/EmbedSlice call.
	if dims <= 0 {
		return nil, fmt.Errorf("invalid embedding dimension %d: must be positive", dims)
	}
	// Load tokenizer using sugarme/tokenizer.
	tok, err := pretrained.FromFile(tokenizerPath)
	if err != nil {
		return nil, fmt.Errorf("failed to load tokenizer: %w", err)
	}
	// Input/output names fixed here determine the positional order expected
	// by session.Run in Embed and EmbedSlice.
	session, err := onnxruntime_go.NewDynamicAdvancedSession(
		modelPath, // e.g. onnx/embedgemma/model_q4.onnx
		[]string{"input_ids", "attention_mask"},
		[]string{"sentence_embedding"},
		nil, // default session options
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create ONNX session: %w", err)
	}
	return &ONNXEmbedder{
		session:   session,
		tokenizer: tok,
		dims:      dims,
		logger:    logger,
	}, nil
}
// Embed returns the embedding of a single text.
//
// The text is tokenized with special tokens added and fed to the model as a
// batch of one ([1, seq_len] input_ids and attention_mask). The [1, dims]
// "sentence_embedding" output is copied into a fresh slice, so the result
// does not share memory with the tensor destroyed when Embed returns.
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
	// 1. Tokenize (true = add special tokens).
	// NOTE(review): confirm Encode/GetIDs/GetAttentionMask against the
	// vendored sugarme/tokenizer version's method set.
	encoding, err := e.tokenizer.Encode(text, true)
	if err != nil {
		return nil, fmt.Errorf("tokenization failed: %w", err)
	}
	// The session is fed int64 tensors, so widen the tokenizer's integers.
	ids := encoding.GetIDs()
	mask := encoding.GetAttentionMask()
	inputIDs := make([]int64, len(ids))
	for i, id := range ids {
		inputIDs[i] = int64(id)
	}
	attentionMask := make([]int64, len(mask))
	for i, m := range mask {
		attentionMask[i] = int64(m)
	}

	// 2. Input tensors, shape [1, seq_len].
	seqLen := int64(len(inputIDs))
	inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs)
	if err != nil {
		return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
	}
	defer inputIDsTensor.Destroy()
	maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask)
	if err != nil {
		return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
	}
	defer maskTensor.Destroy()

	// 3. Output tensor, shape [1, dims].
	outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims)))
	if err != nil {
		return nil, fmt.Errorf("failed to create output tensor: %w", err)
	}
	defer outputTensor.Destroy()

	// 4. Run inference. DynamicAdvancedSession.Run is positional: inputs and
	// outputs must follow the name order given to NewDynamicAdvancedSession
	// ({"input_ids", "attention_mask"} -> {"sentence_embedding"}).
	if err := e.session.Run(
		[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
		[]onnxruntime_go.Value{outputTensor},
	); err != nil {
		return nil, fmt.Errorf("inference failed: %w", err)
	}

	// 5. Copy the result out before the deferred Destroy releases the tensor.
	outputData := outputTensor.GetData()
	embedding := make([]float32, len(outputData))
	copy(embedding, outputData)
	return embedding, nil
}
// EmbedSlice embeds all texts in a single batched inference call and returns
// one vector of length e.dims per text, in input order.
//
// Each text is tokenized with special tokens added. Shorter sequences are
// right-padded with zero IDs to the longest one, and their attention_mask
// entries stay zero over the padding. An empty texts slice returns
// (nil, nil) without running the model.
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
	if len(texts) == 0 {
		return nil, nil
	}

	// 1. Tokenize all texts and find the longest sequence for padding.
	encodings := make([]*tokenizer.Encoding, len(texts))
	maxLen := 0
	for i, txt := range texts {
		enc, err := e.tokenizer.Encode(txt, true)
		if err != nil {
			return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err)
		}
		encodings[i] = enc
		if l := len(enc.GetIDs()); l > maxLen {
			maxLen = l
		}
	}
	// A [batch, 0] tensor cannot be run; report it clearly instead.
	if maxLen == 0 {
		return nil, errors.New("tokenization produced no tokens for any input")
	}

	// 2. Build padded, row-major input_ids and attention_mask of shape
	// [batch, maxLen]; make() zero-initialises the padding positions.
	batchSize := len(texts)
	inputIDs := make([]int64, batchSize*maxLen)
	attentionMask := make([]int64, batchSize*maxLen)
	for i, enc := range encodings {
		row := i * maxLen
		mask := enc.GetAttentionMask()
		for j, id := range enc.GetIDs() {
			inputIDs[row+j] = int64(id)
			attentionMask[row+j] = int64(mask[j])
		}
	}

	// 3. Create tensors.
	inputShape := onnxruntime_go.NewShape(int64(batchSize), int64(maxLen))
	inputIDsTensor, err := onnxruntime_go.NewTensor(inputShape, inputIDs)
	if err != nil {
		return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
	}
	defer inputIDsTensor.Destroy()
	maskTensor, err := onnxruntime_go.NewTensor(inputShape, attentionMask)
	if err != nil {
		return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
	}
	defer maskTensor.Destroy()
	outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
		onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create output tensor: %w", err)
	}
	defer outputTensor.Destroy()

	// 4. Run inference. DynamicAdvancedSession.Run is positional: inputs and
	// outputs must follow the name order given to NewDynamicAdvancedSession
	// ({"input_ids", "attention_mask"} -> {"sentence_embedding"}).
	if err := e.session.Run(
		[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
		[]onnxruntime_go.Value{outputTensor},
	); err != nil {
		return nil, fmt.Errorf("inference failed: %w", err)
	}

	// 5. Split the row-major [batch, dims] output into per-text vectors,
	// copying each row out before the deferred Destroy releases the tensor.
	outputData := outputTensor.GetData()
	embeddings := make([][]float32, batchSize)
	for i := range embeddings {
		start := i * e.dims
		emb := make([]float32, e.dims)
		copy(emb, outputData[start:start+e.dims])
		embeddings[i] = emb
	}
	return embeddings, nil
}