200 lines
5.4 KiB
Go
200 lines
5.4 KiB
Go
package rag
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"time"

	"gf-lt/config"
	"gf-lt/models"

	"github.com/takara-ai/go-tokenizers/tokenizers"
	"github.com/yalue/onnxruntime_go"
)
|
|
|
|
// Embedder converts text into dense float32 embedding vectors.
type Embedder interface {
	// Embed returns the embedding vector for a single piece of text.
	Embed(text string) ([]float32, error)

	// EmbedSlice embeds a batch of texts; the returned vectors are expected
	// to line up with lines by position.
	EmbedSlice(lines []string) ([][]float32, error)
}
|
|
|
|
// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.)
|
|
type APIEmbedder struct {
|
|
logger *slog.Logger
|
|
client *http.Client
|
|
cfg *config.Config
|
|
}
|
|
|
|
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
|
|
return &APIEmbedder{
|
|
logger: l,
|
|
client: &http.Client{},
|
|
cfg: cfg,
|
|
}
|
|
}
|
|
|
|
func (a *APIEmbedder) Embed(text string) ([]float32, error) {
|
|
payload, err := json.Marshal(
|
|
map[string]any{"input": text, "encoding_format": "float"},
|
|
)
|
|
if err != nil {
|
|
a.logger.Error("failed to marshal payload", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
|
if err != nil {
|
|
a.logger.Error("failed to create new req", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
if a.cfg.HFToken != "" {
|
|
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
|
}
|
|
resp, err := a.client.Do(req)
|
|
if err != nil {
|
|
a.logger.Error("failed to embed text", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != 200 {
|
|
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
|
a.logger.Error(err.Error())
|
|
return nil, err
|
|
}
|
|
embResp := &models.LCPEmbedResp{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
|
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
|
|
err = errors.New("empty embedding response")
|
|
a.logger.Error("empty embedding response")
|
|
return nil, err
|
|
}
|
|
return embResp.Data[0].Embedding, nil
|
|
}
|
|
|
|
func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
|
payload, err := json.Marshal(
|
|
map[string]any{"input": lines, "encoding_format": "float"},
|
|
)
|
|
if err != nil {
|
|
a.logger.Error("failed to marshal payload", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
|
if err != nil {
|
|
a.logger.Error("failed to create new req", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
if a.cfg.HFToken != "" {
|
|
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
|
}
|
|
resp, err := a.client.Do(req)
|
|
if err != nil {
|
|
a.logger.Error("failed to embed text", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != 200 {
|
|
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
|
a.logger.Error(err.Error())
|
|
return nil, err
|
|
}
|
|
embResp := &models.LCPEmbedResp{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
|
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
|
return nil, err
|
|
}
|
|
if len(embResp.Data) == 0 {
|
|
err = errors.New("empty embedding response")
|
|
a.logger.Error("empty embedding response")
|
|
return nil, err
|
|
}
|
|
|
|
// Collect all embeddings from the response
|
|
embeddings := make([][]float32, len(embResp.Data))
|
|
for i := range embResp.Data {
|
|
if len(embResp.Data[i].Embedding) == 0 {
|
|
err = fmt.Errorf("empty embedding at index %d", i)
|
|
a.logger.Error("empty embedding", "index", i)
|
|
return nil, err
|
|
}
|
|
embeddings[i] = embResp.Data[i].Embedding
|
|
}
|
|
|
|
// Sort embeddings by index to match the order of input lines
|
|
// API responses may not be in order
|
|
for _, data := range embResp.Data {
|
|
if data.Index >= len(embeddings) || data.Index < 0 {
|
|
err = fmt.Errorf("invalid embedding index %d", data.Index)
|
|
a.logger.Error("invalid embedding index", "index", data.Index)
|
|
return nil, err
|
|
}
|
|
embeddings[data.Index] = data.Embedding
|
|
}
|
|
return embeddings, nil
|
|
}
|
|
|
|
// 1. Loading ONNX models locally
|
|
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
|
|
// 3. Converting text to embeddings without external API calls
|
|
|
|
type ONNXEmbedder struct {
|
|
session *onnxruntime_go.DynamicAdvancedSession
|
|
tokenizer *tokenizers.Tokenizer
|
|
dims int // 768, 512, 256, or 128 for Matryoshka
|
|
}
|
|
|
|
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
|
// Batch processing
|
|
inputs := e.prepareBatch(texts)
|
|
outputs := make([][]float32, len(texts))
|
|
|
|
// Run batch inference (much faster)
|
|
err := e.session.Run(inputs, outputs)
|
|
return outputs, err
|
|
}
|
|
|
|
func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) {
|
|
// Load ONNX model
|
|
session, err := onnxruntime_go.NewDynamicAdvancedSession(
|
|
modelPath, // onnx/embedgemma/model_q4.onnx
|
|
[]string{"input_ids", "attention_mask"},
|
|
[]string{"sentence_embedding"},
|
|
nil,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Load tokenizer (from Hugging Face)
|
|
tokenizer, err := tokenizers.FromFile("./tokenizer.json")
|
|
return &ONNXEmbedder{
|
|
session: session,
|
|
tokenizer: tokenizer,
|
|
}, nil
|
|
}
|
|
|
|
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
|
// Tokenize
|
|
tokens := e.tokenizer.Encode(text, true)
|
|
// Prepare inputs
|
|
inputIDs := []int64{tokens.GetIds()}
|
|
attentionMask := []int64{tokens.GetAttentionMask()}
|
|
// Run inference
|
|
output := onnxruntime_go.NewEmptyTensor[float32](
|
|
onnxruntime_go.NewShape(1, 768),
|
|
)
|
|
err := e.session.Run(
|
|
map[string]any{
|
|
"input_ids": inputIDs,
|
|
"attention_mask": attentionMask,
|
|
},
|
|
[]string{"sentence_embedding"},
|
|
[]any{&output},
|
|
)
|
|
return output.GetData(), nil
|
|
}
|