package rag

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"

	"gf-lt/config"
	"gf-lt/models"

	"github.com/sugarme/tokenizer"
	"github.com/sugarme/tokenizer/pretrained"
	"github.com/yalue/onnxruntime_go"
)

// Embedder defines the interface for embedding text.
type Embedder interface {
	Embed(text string) ([]float32, error)
	EmbedSlice(lines []string) ([][]float32, error)
}

// APIEmbedder implements Embedder using a remote HTTP embedding API
// (llama.cpp server, Hugging Face, OpenAI-compatible endpoints, etc.).
type APIEmbedder struct {
	logger *slog.Logger
	client *http.Client
	cfg    *config.Config
}

// NewAPIEmbedder returns an APIEmbedder that POSTs to cfg.EmbedURL,
// adding a bearer token when cfg.HFToken is set.
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
	// NOTE(review): the client has no Timeout, so a stalled server blocks
	// callers indefinitely — consider making this configurable.
	return &APIEmbedder{
		logger: l,
		client: &http.Client{},
		cfg:    cfg,
	}
}

// request marshals {"input": input, "encoding_format": "float"}, POSTs it to
// the configured embedding endpoint, and decodes the JSON response.
// input may be a single string or a slice of strings.
func (a *APIEmbedder) request(input any) (*models.LCPEmbedResp, error) {
	payload, err := json.Marshal(
		map[string]any{"input": input, "encoding_format": "float"},
	)
	if err != nil {
		a.logger.Error("failed to marshal payload", "err", err.Error())
		return nil, err
	}
	req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
	if err != nil {
		a.logger.Error("failed to create new req", "err", err.Error())
		return nil, err
	}
	// The body is JSON; some servers reject requests without this header.
	req.Header.Set("Content-Type", "application/json")
	if a.cfg.HFToken != "" {
		req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
	}
	resp, err := a.client.Do(req)
	if err != nil {
		a.logger.Error("failed to embed text", "err", err.Error())
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
		a.logger.Error(err.Error())
		return nil, err
	}
	embResp := &models.LCPEmbedResp{}
	if err := json.NewDecoder(resp.Body).Decode(embResp); err != nil {
		a.logger.Error("failed to decode embedding response", "err", err.Error())
		return nil, err
	}
	return embResp, nil
}

// Embed returns the embedding vector for a single text.
func (a *APIEmbedder) Embed(text string) ([]float32, error) {
	embResp, err := a.request(text)
	if err != nil {
		return nil, err
	}
	if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
		err = errors.New("empty embedding response")
		a.logger.Error("empty embedding response")
		return nil, err
	}
	return embResp.Data[0].Embedding, nil
}

// EmbedSlice embeds each line in one request, returning vectors in input
// order: the API reports a per-item Index and may respond out of order,
// so results are placed by Index rather than by response position.
func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
	embResp, err := a.request(lines)
	if err != nil {
		return nil, err
	}
	if len(embResp.Data) == 0 {
		err = errors.New("empty embedding response")
		a.logger.Error("empty embedding response")
		return nil, err
	}
	embeddings := make([][]float32, len(embResp.Data))
	for i, data := range embResp.Data {
		if len(data.Embedding) == 0 {
			err = fmt.Errorf("empty embedding at index %d", i)
			a.logger.Error("empty embedding", "index", i)
			return nil, err
		}
		if data.Index < 0 || data.Index >= len(embeddings) {
			err = fmt.Errorf("invalid embedding index %d", data.Index)
			a.logger.Error("invalid embedding index", "index", data.Index)
			return nil, err
		}
		embeddings[data.Index] = data.Embedding
	}
	// A duplicate Index would leave another slot unfilled; verify coverage.
	for i, emb := range embeddings {
		if emb == nil {
			err = fmt.Errorf("missing embedding for index %d", i)
			a.logger.Error("missing embedding", "index", i)
			return nil, err
		}
	}
	return embeddings, nil
}

// ONNXEmbedder (below) avoids external API calls by:
// 1. Loading ONNX models locally
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
Converting text to embeddings without external API calls type ONNXEmbedder struct { session *onnxruntime_go.DynamicAdvancedSession tokenizer *tokenizer.Tokenizer dims int // embedding dimension (e.g., 768) logger *slog.Logger } func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) { // Load tokenizer using sugarme/tokenizer tok, err := pretrained.FromFile(tokenizerPath) if err != nil { return nil, fmt.Errorf("failed to load tokenizer: %w", err) } // Create ONNX session session, err := onnxruntime_go.NewDynamicAdvancedSession( modelPath, // onnx/embedgemma/model_q4.onnx []string{"input_ids", "attention_mask"}, []string{"sentence_embedding"}, nil, // optional options ) if err != nil { return nil, fmt.Errorf("failed to create ONNX session: %w", err) } return &ONNXEmbedder{ session: session, tokenizer: tok, dims: dims, logger: logger, }, nil } func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { // 1. Tokenize encoding, err := e.tokenizer.Encode(text, true) // true = add special tokens if err != nil { return nil, fmt.Errorf("tokenization failed: %w", err) } // Convert []int32 to []int64 for ONNX inputIDs := make([]int64, len(encoding.GetIDs())) for i, id := range encoding.GetIDs() { inputIDs[i] = int64(id) } attentionMask := make([]int64, len(encoding.GetAttentionMask())) for i, m := range encoding.GetAttentionMask() { attentionMask[i] = int64(m) } // 2. Create input tensors (shape: [1, seq_len]) seqLen := int64(len(inputIDs)) inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs) if err != nil { return nil, fmt.Errorf("failed to create input_ids tensor: %w", err) } defer inputIDsTensor.Destroy() maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask) if err != nil { return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err) } defer maskTensor.Destroy() // 3. 
Create output tensor (shape: [1, dims]) outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims))) if err != nil { return nil, fmt.Errorf("failed to create output tensor: %w", err) } defer outputTensor.Destroy() // 4. Run inference err = e.session.Run( map[string]*onnxruntime_go.Tensor{ "input_ids": inputIDsTensor, "attention_mask": maskTensor, }, []string{"sentence_embedding"}, []*onnxruntime_go.Tensor{outputTensor}, ) if err != nil { return nil, fmt.Errorf("inference failed: %w", err) } // 5. Extract data outputData := outputTensor.GetData() // outputTensor is owned by us, but GetData returns a slice that remains valid until Destroy. // We need to copy if we want to keep it after Destroy (we defer Destroy, so copy now). embedding := make([]float32, len(outputData)) copy(embedding, outputData) return embedding, nil } // EmbedSlice (batch) – to be implemented properly func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) { if len(texts) == 0 { return nil, nil } // 1. Tokenize all texts and find max length for padding encodings := make([]*tokenizer.Encoding, len(texts)) maxLen := 0 for i, txt := range texts { enc, err := e.tokenizer.Encode(txt, true) if err != nil { return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err) } encodings[i] = enc if l := len(enc.GetIDs()); l > maxLen { maxLen = l } } // 2. Build padded input_ids and attention_mask (shape: [batch, maxLen]) batchSize := len(texts) inputIDs := make([]int64, batchSize*maxLen) attentionMask := make([]int64, batchSize*maxLen) for i, enc := range encodings { ids := enc.GetIDs() mask := enc.GetAttentionMask() offset := i * maxLen // copy actual tokens for j := 0; j < len(ids); j++ { inputIDs[offset+j] = int64(ids[j]) attentionMask[offset+j] = int64(mask[j]) } // remaining positions (padding) are already zero-initialized } // 3. 
Create tensors inputIDsTensor, err := onnxruntime_go.NewTensor( onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), inputIDs, ) if err != nil { return nil, err } defer inputIDsTensor.Destroy() maskTensor, err := onnxruntime_go.NewTensor( onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), attentionMask, ) if err != nil { return nil, err } defer maskTensor.Destroy() outputTensor, err := onnxruntime_go.NewEmptyTensor[float32]( onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)), ) if err != nil { return nil, err } defer outputTensor.Destroy() // 4. Run err = e.session.Run( map[string]*onnxruntime_go.Tensor{ "input_ids": inputIDsTensor, "attention_mask": maskTensor, }, []string{"sentence_embedding"}, []*onnxruntime_go.Tensor{outputTensor}, ) if err != nil { return nil, err } // 5. Extract batch results outputData := outputTensor.GetData() embeddings := make([][]float32, batchSize) for i := 0; i < batchSize; i++ { start := i * e.dims emb := make([]float32, e.dims) copy(emb, outputData[start:start+e.dims]) embeddings[i] = emb } return embeddings, nil }