Enha: embedgemma model
This commit is contained in:
15
models/embed.go
Normal file
15
models/embed.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
type LCPEmbedResp struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
Data []struct {
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
@@ -6,14 +6,15 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gf-lt/config"
|
"gf-lt/config"
|
||||||
|
"gf-lt/models"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Embedder defines the interface for embedding text
|
// Embedder defines the interface for embedding text
|
||||||
type Embedder interface {
|
type Embedder interface {
|
||||||
Embed(text []string) ([][]float32, error)
|
Embed(text string) ([]float32, error)
|
||||||
EmbedSingle(text string) ([]float32, error)
|
EmbedSlice(lines []string) ([][]float32, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.)
|
// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.)
|
||||||
@@ -31,62 +32,107 @@ func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *APIEmbedder) Embed(text []string) ([][]float32, error) {
|
func (a *APIEmbedder) Embed(text string) ([]float32, error) {
|
||||||
payload, err := json.Marshal(
|
payload, err := json.Marshal(
|
||||||
map[string]any{"inputs": text, "options": map[string]bool{"wait_for_model": true}},
|
map[string]any{"input": text, "encoding_format": "float"},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.Error("failed to marshal payload", "err", err.Error())
|
a.logger.Error("failed to marshal payload", "err", err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.Error("failed to create new req", "err", err.Error())
|
a.logger.Error("failed to create new req", "err", err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if a.cfg.HFToken != "" {
|
if a.cfg.HFToken != "" {
|
||||||
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := a.client.Do(req)
|
resp, err := a.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.Error("failed to embed text", "err", err.Error())
|
a.logger.Error("failed to embed text", "err", err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
||||||
a.logger.Error(err.Error())
|
a.logger.Error(err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
embResp := &models.LCPEmbedResp{}
|
||||||
var emb [][]float32
|
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
|
|
||||||
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
|
||||||
|
err = errors.New("empty embedding response")
|
||||||
|
a.logger.Error("empty embedding response")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return embResp.Data[0].Embedding, nil
|
||||||
|
}
|
||||||
|
|
||||||
if len(emb) == 0 {
|
func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
||||||
|
payload, err := json.Marshal(
|
||||||
|
map[string]any{"input": lines, "encoding_format": "float"},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error("failed to marshal payload", "err", err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error("failed to create new req", "err", err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if a.cfg.HFToken != "" {
|
||||||
|
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
|
||||||
|
}
|
||||||
|
resp, err := a.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error("failed to embed text", "err", err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
|
||||||
|
a.logger.Error(err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
embResp := &models.LCPEmbedResp{}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
|
||||||
|
a.logger.Error("failed to decode embedding response", "err", err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(embResp.Data) == 0 {
|
||||||
err = errors.New("empty embedding response")
|
err = errors.New("empty embedding response")
|
||||||
a.logger.Error("empty embedding response")
|
a.logger.Error("empty embedding response")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return emb, nil
|
// Collect all embeddings from the response
|
||||||
}
|
embeddings := make([][]float32, len(embResp.Data))
|
||||||
|
for i := range embResp.Data {
|
||||||
func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) {
|
if len(embResp.Data[i].Embedding) == 0 {
|
||||||
result, err := a.Embed([]string{text})
|
err = fmt.Errorf("empty embedding at index %d", i)
|
||||||
if err != nil {
|
a.logger.Error("empty embedding", "index", i)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(result) == 0 {
|
embeddings[i] = embResp.Data[i].Embedding
|
||||||
return nil, errors.New("no embeddings returned")
|
|
||||||
}
|
}
|
||||||
return result[0], nil
|
|
||||||
|
// Sort embeddings by index to match the order of input lines
|
||||||
|
// API responses may not be in order
|
||||||
|
for _, data := range embResp.Data {
|
||||||
|
if data.Index >= len(embeddings) || data.Index < 0 {
|
||||||
|
err = fmt.Errorf("invalid embedding index %d", data.Index)
|
||||||
|
a.logger.Error("invalid embedding index", "index", data.Index)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
embeddings[data.Index] = data.Embedding
|
||||||
|
}
|
||||||
|
|
||||||
|
return embeddings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: ONNXEmbedder implementation would go here
|
// TODO: ONNXEmbedder implementation would go here
|
||||||
@@ -97,4 +143,3 @@ func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) {
|
|||||||
//
|
//
|
||||||
// For now, we'll focus on the API implementation which is already working in the current system,
|
// For now, we'll focus on the API implementation which is already working in the current system,
|
||||||
// and can be extended later when we have ONNX runtime integration
|
// and can be extended later when we have ONNX runtime integration
|
||||||
|
|
||||||
|
|||||||
50
rag/rag.go
50
rag/rag.go
@@ -148,10 +148,12 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
|
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
|
||||||
go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
|
go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for embedding to be done
|
// Wait for embedding to be done
|
||||||
<-doneCh
|
<-doneCh
|
||||||
|
err = <-errCh
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// Write vectors to storage
|
// Write vectors to storage
|
||||||
return r.writeVectors(vectorCh)
|
return r.writeVectors(vectorCh)
|
||||||
}
|
}
|
||||||
@@ -178,9 +180,11 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
|||||||
|
|
||||||
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
|
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
|
||||||
vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
|
vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
|
||||||
|
var err error
|
||||||
defer func() {
|
defer func() {
|
||||||
if len(doneCh) == 0 {
|
if len(doneCh) == 0 {
|
||||||
doneCh <- true
|
doneCh <- true
|
||||||
|
errCh <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -201,7 +205,7 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
lock.Unlock()
|
lock.Unlock()
|
||||||
case err := <-errCh:
|
case err = <-errCh:
|
||||||
r.logger.Error("got an error from error channel", "error", err)
|
r.logger.Error("got an error from error channel", "error", err)
|
||||||
lock.Unlock()
|
lock.Unlock()
|
||||||
return
|
return
|
||||||
@@ -215,7 +219,23 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
|
func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
|
||||||
embeddings, err := r.embedder.Embed(lines)
|
// Filter out empty lines before sending to embedder
|
||||||
|
nonEmptyLines := make([]string, 0, len(lines))
|
||||||
|
for _, line := range lines {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
if trimmed != "" {
|
||||||
|
nonEmptyLines = append(nonEmptyLines, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if no non-empty lines
|
||||||
|
if len(nonEmptyLines) == 0 {
|
||||||
|
// Send empty result but don't error
|
||||||
|
vectorCh <- []models.VectorRow{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings, err := r.embedder.EmbedSlice(nonEmptyLines)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to embed lines", "err", err.Error())
|
r.logger.Error("failed to embed lines", "err", err.Error())
|
||||||
errCh <- err
|
errCh <- err
|
||||||
@@ -229,15 +249,22 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
vectors := make([]models.VectorRow, len(embeddings))
|
if len(embeddings) != len(nonEmptyLines) {
|
||||||
for i, emb := range embeddings {
|
err := errors.New("mismatch between number of lines and embeddings returned")
|
||||||
vector := models.VectorRow{
|
r.logger.Error("embedding mismatch", "err", err.Error())
|
||||||
Embeddings: emb,
|
errCh <- err
|
||||||
RawText: lines[i],
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a VectorRow for each line in the batch
|
||||||
|
vectors := make([]models.VectorRow, len(nonEmptyLines))
|
||||||
|
for i, line := range nonEmptyLines {
|
||||||
|
vectors[i] = models.VectorRow{
|
||||||
|
Embeddings: embeddings[i],
|
||||||
|
RawText: line,
|
||||||
Slug: fmt.Sprintf("%s_%d", slug, i),
|
Slug: fmt.Sprintf("%s_%d", slug, i),
|
||||||
FileName: filename,
|
FileName: filename,
|
||||||
}
|
}
|
||||||
vectors[i] = vector
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vectorCh <- vectors
|
vectorCh <- vectors
|
||||||
@@ -245,7 +272,7 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||||
return r.embedder.EmbedSingle(line)
|
return r.embedder.Embed(line)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
|
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
|
||||||
@@ -259,4 +286,3 @@ func (r *RAG) ListLoaded() ([]string, error) {
|
|||||||
func (r *RAG) RemoveFile(filename string) error {
|
func (r *RAG) RemoveFile(filename string) error {
|
||||||
return r.storage.RemoveEmbByFileName(filename)
|
return r.storage.RemoveEmbByFileName(filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -216,7 +216,7 @@ func makeRAGTable(fileList []string) *tview.Flex {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1) // why?
|
||||||
go func() {
|
go func() {
|
||||||
defer pages.RemovePage(RAGPage)
|
defer pages.RemovePage(RAGPage)
|
||||||
for {
|
for {
|
||||||
@@ -273,6 +273,7 @@ func makeRAGTable(fileList []string) *tview.Flex {
|
|||||||
go func() {
|
go func() {
|
||||||
if err := ragger.LoadRAG(fpath); err != nil {
|
if err := ragger.LoadRAG(fpath); err != nil {
|
||||||
logger.Error("failed to embed file", "chat", fpath, "error", err)
|
logger.Error("failed to embed file", "chat", fpath, "error", err)
|
||||||
|
_ = notifyUser("RAG", "failed to embed file; error: "+err.Error())
|
||||||
errCh <- err
|
errCh <- err
|
||||||
// pages.RemovePage(RAGPage)
|
// pages.RemovePage(RAGPage)
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user