package rag
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"gf-lt/config"
|
||
"gf-lt/models"
|
||
"log/slog"
|
||
"net/http"
|
||
|
||
"github.com/sugarme/tokenizer"
|
||
"github.com/sugarme/tokenizer/pretrained"
|
||
"github.com/yalue/onnxruntime_go"
|
||
)
|
||
|
||
// Embedder defines the interface for embedding text.
type Embedder interface {
	// Embed returns the embedding vector for a single piece of text.
	Embed(text string) ([]float32, error)
	// EmbedSlice returns one embedding per input line, in input order.
	EmbedSlice(lines []string) ([][]float32, error)
}
|
||
|
||
// APIEmbedder implements Embedder using an HTTP embedding API
// (like Hugging Face, OpenAI, etc.); the endpoint comes from cfg.EmbedURL.
type APIEmbedder struct {
	logger *slog.Logger
	client *http.Client // shared client, reused across requests
	cfg    *config.Config
}
|
||
|
||
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
|
||
return &APIEmbedder{
|
||
logger: l,
|
||
client: &http.Client{},
|
||
cfg: cfg,
|
||
}
|
||
}
|
||
|
||
func (a *APIEmbedder) Embed(text string) ([]float32, error) {
|
||
payload, err := json.Marshal(
|
||
map[string]any{"input": text, "encoding_format": "float"},
|
||
)
|
||
if err != nil {
|
||
a.logger.Error("failed to marshal payload", "err", err.Error())
|
||
return nil, err
|
||
}
|
||
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
||
if err != nil {
|
||
a.logger.Error("failed to create new req", "err", err.Error())
|
||
return nil, err
|
||
}
|
||
if a.cfg.HFToken != "" {
|
||
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
||
}
|
||
resp, err := a.client.Do(req)
|
||
if err != nil {
|
||
a.logger.Error("failed to embed text", "err", err.Error())
|
||
return nil, err
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode != 200 {
|
||
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
||
a.logger.Error(err.Error())
|
||
return nil, err
|
||
}
|
||
embResp := &models.LCPEmbedResp{}
|
||
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
||
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
||
return nil, err
|
||
}
|
||
if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
|
||
err = errors.New("empty embedding response")
|
||
a.logger.Error("empty embedding response")
|
||
return nil, err
|
||
}
|
||
return embResp.Data[0].Embedding, nil
|
||
}
|
||
|
||
func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
||
payload, err := json.Marshal(
|
||
map[string]any{"input": lines, "encoding_format": "float"},
|
||
)
|
||
if err != nil {
|
||
a.logger.Error("failed to marshal payload", "err", err.Error())
|
||
return nil, err
|
||
}
|
||
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
||
if err != nil {
|
||
a.logger.Error("failed to create new req", "err", err.Error())
|
||
return nil, err
|
||
}
|
||
if a.cfg.HFToken != "" {
|
||
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
||
}
|
||
resp, err := a.client.Do(req)
|
||
if err != nil {
|
||
a.logger.Error("failed to embed text", "err", err.Error())
|
||
return nil, err
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode != 200 {
|
||
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
||
a.logger.Error(err.Error())
|
||
return nil, err
|
||
}
|
||
embResp := &models.LCPEmbedResp{}
|
||
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
||
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
||
return nil, err
|
||
}
|
||
if len(embResp.Data) == 0 {
|
||
err = errors.New("empty embedding response")
|
||
a.logger.Error("empty embedding response")
|
||
return nil, err
|
||
}
|
||
|
||
// Collect all embeddings from the response
|
||
embeddings := make([][]float32, len(embResp.Data))
|
||
for i := range embResp.Data {
|
||
if len(embResp.Data[i].Embedding) == 0 {
|
||
err = fmt.Errorf("empty embedding at index %d", i)
|
||
a.logger.Error("empty embedding", "index", i)
|
||
return nil, err
|
||
}
|
||
embeddings[i] = embResp.Data[i].Embedding
|
||
}
|
||
|
||
// Sort embeddings by index to match the order of input lines
|
||
// API responses may not be in order
|
||
for _, data := range embResp.Data {
|
||
if data.Index >= len(embeddings) || data.Index < 0 {
|
||
err = fmt.Errorf("invalid embedding index %d", data.Index)
|
||
a.logger.Error("invalid embedding index", "index", data.Index)
|
||
return nil, err
|
||
}
|
||
embeddings[data.Index] = data.Embedding
|
||
}
|
||
return embeddings, nil
|
||
}
|
||
|
||
// ONNXEmbedder implements Embedder fully locally by:
// 1. Loading ONNX models locally
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
type ONNXEmbedder struct {
	session   *onnxruntime_go.DynamicAdvancedSession // inference session for the loaded model
	tokenizer *tokenizer.Tokenizer                   // converts text to token IDs / attention mask
	dims      int                                    // embedding dimension (e.g., 768)
	logger    *slog.Logger
}
|
||
|
||
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
|
||
// Load tokenizer using sugarme/tokenizer
|
||
tok, err := pretrained.FromFile(tokenizerPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
|
||
}
|
||
// Create ONNX session
|
||
session, err := onnxruntime_go.NewDynamicAdvancedSession(
|
||
modelPath, // onnx/embedgemma/model_q4.onnx
|
||
[]string{"input_ids", "attention_mask"},
|
||
[]string{"sentence_embedding"},
|
||
nil, // optional options
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create ONNX session: %w", err)
|
||
}
|
||
return &ONNXEmbedder{
|
||
session: session,
|
||
tokenizer: tok,
|
||
dims: dims,
|
||
logger: logger,
|
||
}, nil
|
||
}
|
||
|
||
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||
// 1. Tokenize
|
||
encoding, err := e.tokenizer.Encode(text, true) // true = add special tokens
|
||
if err != nil {
|
||
return nil, fmt.Errorf("tokenization failed: %w", err)
|
||
}
|
||
// Convert []int32 to []int64 for ONNX
|
||
inputIDs := make([]int64, len(encoding.GetIDs()))
|
||
for i, id := range encoding.GetIDs() {
|
||
inputIDs[i] = int64(id)
|
||
}
|
||
attentionMask := make([]int64, len(encoding.GetAttentionMask()))
|
||
for i, m := range encoding.GetAttentionMask() {
|
||
attentionMask[i] = int64(m)
|
||
}
|
||
// 2. Create input tensors (shape: [1, seq_len])
|
||
seqLen := int64(len(inputIDs))
|
||
inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
|
||
}
|
||
defer inputIDsTensor.Destroy()
|
||
maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
|
||
}
|
||
defer maskTensor.Destroy()
|
||
// 3. Create output tensor (shape: [1, dims])
|
||
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims)))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create output tensor: %w", err)
|
||
}
|
||
defer outputTensor.Destroy()
|
||
// 4. Run inference
|
||
err = e.session.Run(
|
||
map[string]*onnxruntime_go.Tensor{
|
||
"input_ids": inputIDsTensor,
|
||
"attention_mask": maskTensor,
|
||
},
|
||
[]string{"sentence_embedding"},
|
||
[]*onnxruntime_go.Tensor{outputTensor},
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("inference failed: %w", err)
|
||
}
|
||
// 5. Extract data
|
||
outputData := outputTensor.GetData()
|
||
// outputTensor is owned by us, but GetData returns a slice that remains valid until Destroy.
|
||
// We need to copy if we want to keep it after Destroy (we defer Destroy, so copy now).
|
||
embedding := make([]float32, len(outputData))
|
||
copy(embedding, outputData)
|
||
return embedding, nil
|
||
}
|
||
|
||
// EmbedSlice (batch) – to be implemented properly
|
||
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
||
if len(texts) == 0 {
|
||
return nil, nil
|
||
}
|
||
// 1. Tokenize all texts and find max length for padding
|
||
encodings := make([]*tokenizer.Encoding, len(texts))
|
||
maxLen := 0
|
||
for i, txt := range texts {
|
||
enc, err := e.tokenizer.Encode(txt, true)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err)
|
||
}
|
||
encodings[i] = enc
|
||
if l := len(enc.GetIDs()); l > maxLen {
|
||
maxLen = l
|
||
}
|
||
}
|
||
// 2. Build padded input_ids and attention_mask (shape: [batch, maxLen])
|
||
batchSize := len(texts)
|
||
inputIDs := make([]int64, batchSize*maxLen)
|
||
attentionMask := make([]int64, batchSize*maxLen)
|
||
for i, enc := range encodings {
|
||
ids := enc.GetIDs()
|
||
mask := enc.GetAttentionMask()
|
||
offset := i * maxLen
|
||
// copy actual tokens
|
||
for j := 0; j < len(ids); j++ {
|
||
inputIDs[offset+j] = int64(ids[j])
|
||
attentionMask[offset+j] = int64(mask[j])
|
||
}
|
||
// remaining positions (padding) are already zero-initialized
|
||
}
|
||
// 3. Create tensors
|
||
inputIDsTensor, err := onnxruntime_go.NewTensor(
|
||
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||
inputIDs,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer inputIDsTensor.Destroy()
|
||
maskTensor, err := onnxruntime_go.NewTensor(
|
||
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||
attentionMask,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer maskTensor.Destroy()
|
||
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
|
||
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer outputTensor.Destroy()
|
||
// 4. Run
|
||
err = e.session.Run(
|
||
map[string]*onnxruntime_go.Tensor{
|
||
"input_ids": inputIDsTensor,
|
||
"attention_mask": maskTensor,
|
||
},
|
||
[]string{"sentence_embedding"},
|
||
[]*onnxruntime_go.Tensor{outputTensor},
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 5. Extract batch results
|
||
outputData := outputTensor.GetData()
|
||
embeddings := make([][]float32, batchSize)
|
||
for i := 0; i < batchSize; i++ {
|
||
start := i * e.dims
|
||
emb := make([]float32, e.dims)
|
||
copy(emb, outputData[start:start+e.dims])
|
||
embeddings[i] = emb
|
||
}
|
||
return embeddings, nil
|
||
}
|