Enha: embedgemma model

This commit is contained in:
Grail Finder
2025-11-22 14:56:24 +03:00
parent 5fe03fa66c
commit 50d7bfced3
4 changed files with 123 additions and 36 deletions

15
models/embed.go Normal file
View File

@@ -0,0 +1,15 @@
package models
// LCPEmbedResp mirrors the JSON document returned by the embedding endpoint.
// NOTE(review): "LCP" presumably stands for llama.cpp — confirm against the
// server this is decoded from.
type LCPEmbedResp struct {
	// Model is decoded from the "model" key of the response.
	Model string `json:"model"`
	// Object is decoded from the top-level "object" key.
	Object string `json:"object"`
	// Usage holds the token counters found under the "usage" key.
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		TotalTokens  int `json:"total_tokens"`
	} `json:"usage"`
	// Data holds the entries of the "data" array; each entry carries one
	// embedding vector together with the index reported for it.
	Data []struct {
		// Embedding is the vector itself, decoded as 32-bit floats.
		Embedding []float32 `json:"embedding"`
		// Index is the entry's "index" value.
		// NOTE(review): presumably the position of the matching input — verify.
		Index int `json:"index"`
		// Object is decoded from the entry's "object" key.
		Object string `json:"object"`
	} `json:"data"`
}

View File

@@ -6,14 +6,15 @@ import (
"errors" "errors"
"fmt" "fmt"
"gf-lt/config" "gf-lt/config"
"gf-lt/models"
"log/slog" "log/slog"
"net/http" "net/http"
) )
// Embedder defines the interface for embedding text // Embedder defines the interface for embedding text
type Embedder interface { type Embedder interface {
Embed(text []string) ([][]float32, error) Embed(text string) ([]float32, error)
EmbedSingle(text string) ([]float32, error) EmbedSlice(lines []string) ([][]float32, error)
} }
// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.) // APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.)
@@ -31,62 +32,107 @@ func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
} }
} }
func (a *APIEmbedder) Embed(text []string) ([][]float32, error) { func (a *APIEmbedder) Embed(text string) ([]float32, error) {
payload, err := json.Marshal( payload, err := json.Marshal(
map[string]any{"inputs": text, "options": map[string]bool{"wait_for_model": true}}, map[string]any{"input": text, "encoding_format": "float"},
) )
if err != nil { if err != nil {
a.logger.Error("failed to marshal payload", "err", err.Error()) a.logger.Error("failed to marshal payload", "err", err.Error())
return nil, err return nil, err
} }
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload)) req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
if err != nil { if err != nil {
a.logger.Error("failed to create new req", "err", err.Error()) a.logger.Error("failed to create new req", "err", err.Error())
return nil, err return nil, err
} }
if a.cfg.HFToken != "" { if a.cfg.HFToken != "" {
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken) req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
} }
resp, err := a.client.Do(req) resp, err := a.client.Do(req)
if err != nil { if err != nil {
a.logger.Error("failed to embed text", "err", err.Error()) a.logger.Error("failed to embed text", "err", err.Error())
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode) err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
a.logger.Error(err.Error()) a.logger.Error(err.Error())
return nil, err return nil, err
} }
embResp := &models.LCPEmbedResp{}
var emb [][]float32 if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
a.logger.Error("failed to decode embedding response", "err", err.Error()) a.logger.Error("failed to decode embedding response", "err", err.Error())
return nil, err return nil, err
} }
if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
err = errors.New("empty embedding response")
a.logger.Error("empty embedding response")
return nil, err
}
return embResp.Data[0].Embedding, nil
}
if len(emb) == 0 { func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
payload, err := json.Marshal(
map[string]any{"input": lines, "encoding_format": "float"},
)
if err != nil {
a.logger.Error("failed to marshal payload", "err", err.Error())
return nil, err
}
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
if err != nil {
a.logger.Error("failed to create new req", "err", err.Error())
return nil, err
}
if a.cfg.HFToken != "" {
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
}
resp, err := a.client.Do(req)
if err != nil {
a.logger.Error("failed to embed text", "err", err.Error())
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
a.logger.Error(err.Error())
return nil, err
}
embResp := &models.LCPEmbedResp{}
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
a.logger.Error("failed to decode embedding response", "err", err.Error())
return nil, err
}
if len(embResp.Data) == 0 {
err = errors.New("empty embedding response") err = errors.New("empty embedding response")
a.logger.Error("empty embedding response") a.logger.Error("empty embedding response")
return nil, err return nil, err
} }
return emb, nil // Collect all embeddings from the response
} embeddings := make([][]float32, len(embResp.Data))
for i := range embResp.Data {
if len(embResp.Data[i].Embedding) == 0 {
err = fmt.Errorf("empty embedding at index %d", i)
a.logger.Error("empty embedding", "index", i)
return nil, err
}
embeddings[i] = embResp.Data[i].Embedding
}
func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) { // Sort embeddings by index to match the order of input lines
result, err := a.Embed([]string{text}) // API responses may not be in order
if err != nil { for _, data := range embResp.Data {
return nil, err if data.Index >= len(embeddings) || data.Index < 0 {
err = fmt.Errorf("invalid embedding index %d", data.Index)
a.logger.Error("invalid embedding index", "index", data.Index)
return nil, err
}
embeddings[data.Index] = data.Embedding
} }
if len(result) == 0 {
return nil, errors.New("no embeddings returned") return embeddings, nil
}
return result[0], nil
} }
// TODO: ONNXEmbedder implementation would go here // TODO: ONNXEmbedder implementation would go here
@@ -97,4 +143,3 @@ func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) {
// //
// For now, we'll focus on the API implementation which is already working in the current system, // For now, we'll focus on the API implementation which is already working in the current system,
// and can be extended later when we have ONNX runtime integration // and can be extended later when we have ONNX runtime integration

View File

@@ -148,10 +148,12 @@ func (r *RAG) LoadRAG(fpath string) error {
for w := 0; w < int(r.cfg.RAGWorkers); w++ { for w := 0; w < int(r.cfg.RAGWorkers); w++ {
go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
} }
// Wait for embedding to be done // Wait for embedding to be done
<-doneCh <-doneCh
err = <-errCh
if err != nil {
return err
}
// Write vectors to storage // Write vectors to storage
return r.writeVectors(vectorCh) return r.writeVectors(vectorCh)
} }
@@ -178,9 +180,11 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) { vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
var err error
defer func() { defer func() {
if len(doneCh) == 0 { if len(doneCh) == 0 {
doneCh <- true doneCh <- true
errCh <- err
} }
}() }()
@@ -201,7 +205,7 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
} }
} }
lock.Unlock() lock.Unlock()
case err := <-errCh: case err = <-errCh:
r.logger.Error("got an error from error channel", "error", err) r.logger.Error("got an error from error channel", "error", err)
lock.Unlock() lock.Unlock()
return return
@@ -215,7 +219,23 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
} }
func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error { func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
embeddings, err := r.embedder.Embed(lines) // Filter out empty lines before sending to embedder
nonEmptyLines := make([]string, 0, len(lines))
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed != "" {
nonEmptyLines = append(nonEmptyLines, trimmed)
}
}
// Skip if no non-empty lines
if len(nonEmptyLines) == 0 {
// Send empty result but don't error
vectorCh <- []models.VectorRow{}
return nil
}
embeddings, err := r.embedder.EmbedSlice(nonEmptyLines)
if err != nil { if err != nil {
r.logger.Error("failed to embed lines", "err", err.Error()) r.logger.Error("failed to embed lines", "err", err.Error())
errCh <- err errCh <- err
@@ -229,15 +249,22 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
return err return err
} }
vectors := make([]models.VectorRow, len(embeddings)) if len(embeddings) != len(nonEmptyLines) {
for i, emb := range embeddings { err := errors.New("mismatch between number of lines and embeddings returned")
vector := models.VectorRow{ r.logger.Error("embedding mismatch", "err", err.Error())
Embeddings: emb, errCh <- err
RawText: lines[i], return err
}
// Create a VectorRow for each line in the batch
vectors := make([]models.VectorRow, len(nonEmptyLines))
for i, line := range nonEmptyLines {
vectors[i] = models.VectorRow{
Embeddings: embeddings[i],
RawText: line,
Slug: fmt.Sprintf("%s_%d", slug, i), Slug: fmt.Sprintf("%s_%d", slug, i),
FileName: filename, FileName: filename,
} }
vectors[i] = vector
} }
vectorCh <- vectors vectorCh <- vectors
@@ -245,7 +272,7 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
} }
func (r *RAG) LineToVector(line string) ([]float32, error) { func (r *RAG) LineToVector(line string) ([]float32, error) {
return r.embedder.EmbedSingle(line) return r.embedder.Embed(line)
} }
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
@@ -259,4 +286,3 @@ func (r *RAG) ListLoaded() ([]string, error) {
func (r *RAG) RemoveFile(filename string) error { func (r *RAG) RemoveFile(filename string) error {
return r.storage.RemoveEmbByFileName(filename) return r.storage.RemoveEmbByFileName(filename)
} }

View File

@@ -216,7 +216,7 @@ func makeRAGTable(fileList []string) *tview.Flex {
} }
} }
} }
errCh := make(chan error, 1) errCh := make(chan error, 1) // why?
go func() { go func() {
defer pages.RemovePage(RAGPage) defer pages.RemovePage(RAGPage)
for { for {
@@ -273,6 +273,7 @@ func makeRAGTable(fileList []string) *tview.Flex {
go func() { go func() {
if err := ragger.LoadRAG(fpath); err != nil { if err := ragger.LoadRAG(fpath); err != nil {
logger.Error("failed to embed file", "chat", fpath, "error", err) logger.Error("failed to embed file", "chat", fpath, "error", err)
_ = notifyUser("RAG", "failed to embed file; error: "+err.Error())
errCh <- err errCh <- err
// pages.RemovePage(RAGPage) // pages.RemovePage(RAGPage)
return return