290 lines
8.1 KiB
Go
290 lines
8.1 KiB
Go
package rag
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"gf-lt/config"
|
|
"gf-lt/models"
|
|
"log/slog"
|
|
"net/http"
|
|
|
|
"github.com/sugarme/tokenizer"
|
|
"github.com/sugarme/tokenizer/pretrained"
|
|
"github.com/yalue/onnxruntime_go"
|
|
)
|
|
|
|
// Embedder converts text into dense float32 embedding vectors.
type Embedder interface {
	// Embed returns the embedding vector for a single piece of text.
	Embed(text string) ([]float32, error)
	// EmbedSlice embeds several texts at once, returning one vector per
	// input text.
	EmbedSlice(texts []string) ([][]float32, error)
}
|
|
|
|
// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.)
|
|
type APIEmbedder struct {
|
|
logger *slog.Logger
|
|
client *http.Client
|
|
cfg *config.Config
|
|
}
|
|
|
|
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
|
|
return &APIEmbedder{
|
|
logger: l,
|
|
client: &http.Client{},
|
|
cfg: cfg,
|
|
}
|
|
}
|
|
|
|
func (a *APIEmbedder) Embed(text string) ([]float32, error) {
|
|
payload, err := json.Marshal(
|
|
map[string]any{"input": text, "encoding_format": "float"},
|
|
)
|
|
if err != nil {
|
|
a.logger.Error("failed to marshal payload", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
|
if err != nil {
|
|
a.logger.Error("failed to create new req", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
if a.cfg.HFToken != "" {
|
|
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
|
}
|
|
resp, err := a.client.Do(req)
|
|
if err != nil {
|
|
a.logger.Error("failed to embed text", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != 200 {
|
|
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
|
a.logger.Error(err.Error())
|
|
return nil, err
|
|
}
|
|
embResp := &models.LCPEmbedResp{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
|
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
|
|
err = errors.New("empty embedding response")
|
|
a.logger.Error("empty embedding response")
|
|
return nil, err
|
|
}
|
|
return embResp.Data[0].Embedding, nil
|
|
}
|
|
|
|
func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
|
payload, err := json.Marshal(
|
|
map[string]any{"input": lines, "encoding_format": "float"},
|
|
)
|
|
if err != nil {
|
|
a.logger.Error("failed to marshal payload", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
|
if err != nil {
|
|
a.logger.Error("failed to create new req", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
if a.cfg.HFToken != "" {
|
|
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
|
}
|
|
resp, err := a.client.Do(req)
|
|
if err != nil {
|
|
a.logger.Error("failed to embed text", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != 200 {
|
|
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
|
a.logger.Error(err.Error())
|
|
return nil, err
|
|
}
|
|
embResp := &models.LCPEmbedResp{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
|
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
if len(embResp.Data) == 0 {
|
|
err = errors.New("empty embedding response")
|
|
a.logger.Error("empty embedding response")
|
|
return nil, err
|
|
}
|
|
|
|
// Collect all embeddings from the response
|
|
embeddings := make([][]float32, len(embResp.Data))
|
|
for i := range embResp.Data {
|
|
if len(embResp.Data[i].Embedding) == 0 {
|
|
err = fmt.Errorf("empty embedding at index %d", i)
|
|
a.logger.Error("empty embedding", "index", i)
|
|
return nil, err
|
|
}
|
|
embeddings[i] = embResp.Data[i].Embedding
|
|
}
|
|
|
|
// Sort embeddings by index to match the order of input lines
|
|
// API responses may not be in order
|
|
for _, data := range embResp.Data {
|
|
if data.Index >= len(embeddings) || data.Index < 0 {
|
|
err = fmt.Errorf("invalid embedding index %d", data.Index)
|
|
a.logger.Error("invalid embedding index", "index", data.Index)
|
|
return nil, err
|
|
}
|
|
embeddings[data.Index] = data.Embedding
|
|
}
|
|
return embeddings, nil
|
|
}
|
|
|
|
// 1. Loading ONNX models locally
|
|
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
|
|
// 3. Converting text to embeddings without external API calls
|
|
type ONNXEmbedder struct {
|
|
session *onnxruntime_go.DynamicAdvancedSession
|
|
tokenizer *tokenizer.Tokenizer
|
|
dims int // embedding dimension (e.g., 768)
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
|
|
// Load tokenizer using sugarme/tokenizer
|
|
tok, err := pretrained.FromFile(tokenizerPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
|
|
}
|
|
// Create ONNX session
|
|
session, err := onnxruntime_go.NewDynamicAdvancedSession(
|
|
modelPath, // onnx/embedgemma/model_q4.onnx
|
|
[]string{"input_ids", "attention_mask"},
|
|
[]string{"sentence_embedding"},
|
|
nil, // optional options
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create ONNX session: %w", err)
|
|
}
|
|
return &ONNXEmbedder{
|
|
session: session,
|
|
tokenizer: tok,
|
|
dims: dims,
|
|
logger: logger,
|
|
}, nil
|
|
}
|
|
|
|
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
|
// 1. Tokenize
|
|
encoding, err := e.tokenizer.EncodeSingle(text)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("tokenization failed: %w", err)
|
|
}
|
|
// 2. Convert to int64 and create attention mask
|
|
ids := encoding.Ids
|
|
inputIDs := make([]int64, len(ids))
|
|
attentionMask := make([]int64, len(ids))
|
|
for i, id := range ids {
|
|
inputIDs[i] = int64(id)
|
|
attentionMask[i] = 1
|
|
}
|
|
// 3. Create input tensors (shape: [1, seq_len])
|
|
seqLen := int64(len(inputIDs))
|
|
inputIDsTensor, err := onnxruntime_go.NewTensor[int64](
|
|
onnxruntime_go.NewShape(1, seqLen),
|
|
inputIDs,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
|
|
}
|
|
defer inputIDsTensor.Destroy()
|
|
maskTensor, err := onnxruntime_go.NewTensor[int64](
|
|
onnxruntime_go.NewShape(1, seqLen),
|
|
attentionMask,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
|
|
}
|
|
defer maskTensor.Destroy()
|
|
// 4. Create output tensor
|
|
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
|
|
onnxruntime_go.NewShape(1, int64(e.dims)),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create output tensor: %w", err)
|
|
}
|
|
defer outputTensor.Destroy()
|
|
// 5. Run inference
|
|
err = e.session.Run(
|
|
[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
|
|
[]string{"sentence_embedding"},
|
|
[]onnxruntime_go.Value{outputTensor},
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("inference failed: %w", err)
|
|
}
|
|
// 6. Copy output data
|
|
outputData := outputTensor.GetData()
|
|
embedding := make([]float32, len(outputData))
|
|
copy(embedding, outputData)
|
|
return embedding, nil
|
|
}
|
|
|
|
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
|
encodings := make([]*tokenizer.Encoding, len(texts))
|
|
maxLen := 0
|
|
for i, txt := range texts {
|
|
enc, err := e.tokenizer.EncodeSingle(txt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
encodings[i] = enc
|
|
if l := len(enc.Ids); l > maxLen {
|
|
maxLen = l
|
|
}
|
|
}
|
|
batchSize := len(texts)
|
|
inputIDs := make([]int64, batchSize*maxLen)
|
|
attentionMask := make([]int64, batchSize*maxLen)
|
|
for i, enc := range encodings {
|
|
ids := enc.Ids
|
|
offset := i * maxLen
|
|
for j, id := range ids {
|
|
inputIDs[offset+j] = int64(id)
|
|
attentionMask[offset+j] = 1
|
|
}
|
|
// Remaining positions are already zero (padding)
|
|
}
|
|
// Create tensors with shape [batchSize, maxLen]
|
|
inputTensor, _ := onnxruntime_go.NewTensor[int64](
|
|
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
|
inputIDs,
|
|
)
|
|
defer inputTensor.Destroy()
|
|
maskTensor, _ := onnxruntime_go.NewTensor[int64](
|
|
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
|
attentionMask,
|
|
)
|
|
defer maskTensor.Destroy()
|
|
outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32](
|
|
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
|
|
)
|
|
defer outputTensor.Destroy()
|
|
err := e.session.Run(
|
|
[]onnxruntime_go.Value{inputTensor, maskTensor},
|
|
[]string{"sentence_embedding"},
|
|
[]onnxruntime_go.Value{outputTensor},
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Extract embeddings per batch item
|
|
data := outputTensor.GetData()
|
|
embeddings := make([][]float32, batchSize)
|
|
for i := 0; i < batchSize; i++ {
|
|
start := i * e.dims
|
|
emb := make([]float32, e.dims)
|
|
copy(emb, data[start:start+e.dims])
|
|
embeddings[i] = emb
|
|
}
|
|
return embeddings, nil
|
|
}
|