Enha: embedgemma model
This commit is contained in:
@@ -6,14 +6,15 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gf-lt/config"
|
||||
"gf-lt/models"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Embedder defines the interface for embedding text
|
||||
type Embedder interface {
|
||||
Embed(text []string) ([][]float32, error)
|
||||
EmbedSingle(text string) ([]float32, error)
|
||||
Embed(text string) ([]float32, error)
|
||||
EmbedSlice(lines []string) ([][]float32, error)
|
||||
}
|
||||
|
||||
// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.)
|
||||
@@ -31,62 +32,107 @@ func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *APIEmbedder) Embed(text []string) ([][]float32, error) {
|
||||
func (a *APIEmbedder) Embed(text string) ([]float32, error) {
|
||||
payload, err := json.Marshal(
|
||||
map[string]any{"inputs": text, "options": map[string]bool{"wait_for_model": true}},
|
||||
map[string]any{"input": text, "encoding_format": "float"},
|
||||
)
|
||||
if err != nil {
|
||||
a.logger.Error("failed to marshal payload", "err", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
a.logger.Error("failed to create new req", "err", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.cfg.HFToken != "" {
|
||||
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
||||
}
|
||||
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
a.logger.Error("failed to embed text", "err", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
||||
a.logger.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var emb [][]float32
|
||||
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
|
||||
embResp := &models.LCPEmbedResp{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
||||
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
|
||||
err = errors.New("empty embedding response")
|
||||
a.logger.Error("empty embedding response")
|
||||
return nil, err
|
||||
}
|
||||
return embResp.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
if len(emb) == 0 {
|
||||
func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
||||
payload, err := json.Marshal(
|
||||
map[string]any{"input": lines, "encoding_format": "float"},
|
||||
)
|
||||
if err != nil {
|
||||
a.logger.Error("failed to marshal payload", "err", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
a.logger.Error("failed to create new req", "err", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
if a.cfg.HFToken != "" {
|
||||
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
||||
}
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
a.logger.Error("failed to embed text", "err", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
||||
a.logger.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
embResp := &models.LCPEmbedResp{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
||||
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
if len(embResp.Data) == 0 {
|
||||
err = errors.New("empty embedding response")
|
||||
a.logger.Error("empty embedding response")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return emb, nil
|
||||
}
|
||||
// Collect all embeddings from the response
|
||||
embeddings := make([][]float32, len(embResp.Data))
|
||||
for i := range embResp.Data {
|
||||
if len(embResp.Data[i].Embedding) == 0 {
|
||||
err = fmt.Errorf("empty embedding at index %d", i)
|
||||
a.logger.Error("empty embedding", "index", i)
|
||||
return nil, err
|
||||
}
|
||||
embeddings[i] = embResp.Data[i].Embedding
|
||||
}
|
||||
|
||||
func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) {
|
||||
result, err := a.Embed([]string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Sort embeddings by index to match the order of input lines
|
||||
// API responses may not be in order
|
||||
for _, data := range embResp.Data {
|
||||
if data.Index >= len(embeddings) || data.Index < 0 {
|
||||
err = fmt.Errorf("invalid embedding index %d", data.Index)
|
||||
a.logger.Error("invalid embedding index", "index", data.Index)
|
||||
return nil, err
|
||||
}
|
||||
embeddings[data.Index] = data.Embedding
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errors.New("no embeddings returned")
|
||||
}
|
||||
return result[0], nil
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// TODO: ONNXEmbedder implementation would go here
|
||||
@@ -97,4 +143,3 @@ func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) {
|
||||
//
|
||||
// For now, we'll focus on the API implementation which is already working in the current system,
|
||||
// and can be extended later when we have ONNX runtime integration
|
||||
|
||||
|
||||
Reference in New Issue
Block a user