Enha: embedgemma model
This commit is contained in:
50
rag/rag.go
50
rag/rag.go
@@ -148,10 +148,12 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
|
||||
go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
|
||||
}
|
||||
|
||||
// Wait for embedding to be done
|
||||
<-doneCh
|
||||
|
||||
err = <-errCh
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Write vectors to storage
|
||||
return r.writeVectors(vectorCh)
|
||||
}
|
||||
@@ -178,9 +180,11 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
||||
|
||||
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
|
||||
vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
|
||||
var err error
|
||||
defer func() {
|
||||
if len(doneCh) == 0 {
|
||||
doneCh <- true
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -201,7 +205,7 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
|
||||
}
|
||||
}
|
||||
lock.Unlock()
|
||||
case err := <-errCh:
|
||||
case err = <-errCh:
|
||||
r.logger.Error("got an error from error channel", "error", err)
|
||||
lock.Unlock()
|
||||
return
|
||||
@@ -215,7 +219,23 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
|
||||
}
|
||||
|
||||
func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
|
||||
embeddings, err := r.embedder.Embed(lines)
|
||||
// Filter out empty lines before sending to embedder
|
||||
nonEmptyLines := make([]string, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed != "" {
|
||||
nonEmptyLines = append(nonEmptyLines, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
// Skip if no non-empty lines
|
||||
if len(nonEmptyLines) == 0 {
|
||||
// Send empty result but don't error
|
||||
vectorCh <- []models.VectorRow{}
|
||||
return nil
|
||||
}
|
||||
|
||||
embeddings, err := r.embedder.EmbedSlice(nonEmptyLines)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to embed lines", "err", err.Error())
|
||||
errCh <- err
|
||||
@@ -229,15 +249,22 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
|
||||
return err
|
||||
}
|
||||
|
||||
vectors := make([]models.VectorRow, len(embeddings))
|
||||
for i, emb := range embeddings {
|
||||
vector := models.VectorRow{
|
||||
Embeddings: emb,
|
||||
RawText: lines[i],
|
||||
if len(embeddings) != len(nonEmptyLines) {
|
||||
err := errors.New("mismatch between number of lines and embeddings returned")
|
||||
r.logger.Error("embedding mismatch", "err", err.Error())
|
||||
errCh <- err
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a VectorRow for each line in the batch
|
||||
vectors := make([]models.VectorRow, len(nonEmptyLines))
|
||||
for i, line := range nonEmptyLines {
|
||||
vectors[i] = models.VectorRow{
|
||||
Embeddings: embeddings[i],
|
||||
RawText: line,
|
||||
Slug: fmt.Sprintf("%s_%d", slug, i),
|
||||
FileName: filename,
|
||||
}
|
||||
vectors[i] = vector
|
||||
}
|
||||
|
||||
vectorCh <- vectors
|
||||
@@ -245,7 +272,7 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
|
||||
}
|
||||
|
||||
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||
return r.embedder.EmbedSingle(line)
|
||||
return r.embedder.Embed(line)
|
||||
}
|
||||
|
||||
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
|
||||
@@ -259,4 +286,3 @@ func (r *RAG) ListLoaded() ([]string, error) {
|
||||
func (r *RAG) RemoveFile(filename string) error {
|
||||
return r.storage.RemoveEmbByFileName(filename)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user