Enhance (rag): single-threaded processing

This commit is contained in:
Grail Finder
2026-02-24 15:28:18 +03:00
parent 34cd4ac141
commit 78059083c2
2 changed files with 68 additions and 207 deletions

View File

@@ -29,6 +29,7 @@ type RAG struct {
cfg *config.Config cfg *config.Config
embedder Embedder embedder Embedder
storage *VectorStorage storage *VectorStorage
mu sync.Mutex
} }
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
@@ -53,6 +54,8 @@ func wordCounter(sentence string) int {
} }
func (r *RAG) LoadRAG(fpath string) error { func (r *RAG) LoadRAG(fpath string) error {
r.mu.Lock()
defer r.mu.Unlock()
data, err := os.ReadFile(fpath) data, err := os.ReadFile(fpath)
if err != nil { if err != nil {
return err return err
@@ -62,9 +65,7 @@ func (r *RAG) LoadRAG(fpath string) error {
case LongJobStatusCh <- LoadedFileRAGStatus: case LongJobStatusCh <- LoadedFileRAGStatus:
default: default:
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus) r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
// Channel is full or closed, ignore the message to prevent panic
} }
fileText := string(data) fileText := string(data)
tokenizer, err := english.NewSentenceTokenizer(nil) tokenizer, err := english.NewSentenceTokenizer(nil)
if err != nil { if err != nil {
@@ -75,19 +76,16 @@ func (r *RAG) LoadRAG(fpath string) error {
for i, s := range sentences { for i, s := range sentences {
sents[i] = s.Text sents[i] = s.Text
} }
// Group sentences into paragraphs based on word limit // Group sentences into paragraphs based on word limit
paragraphs := []string{} paragraphs := []string{}
par := strings.Builder{} par := strings.Builder{}
for i := 0; i < len(sents); i++ { for i := 0; i < len(sents); i++ {
// Only add sentences that aren't empty
if strings.TrimSpace(sents[i]) != "" { if strings.TrimSpace(sents[i]) != "" {
if par.Len() > 0 { if par.Len() > 0 {
par.WriteString(" ") // Add space between sentences par.WriteString(" ")
} }
par.WriteString(sents[i]) par.WriteString(sents[i])
} }
if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) { if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
paragraph := strings.TrimSpace(par.String()) paragraph := strings.TrimSpace(par.String())
if paragraph != "" { if paragraph != "" {
@@ -96,7 +94,6 @@ func (r *RAG) LoadRAG(fpath string) error {
par.Reset() par.Reset()
} }
} }
// Handle any remaining content in the paragraph buffer // Handle any remaining content in the paragraph buffer
if par.Len() > 0 { if par.Len() > 0 {
paragraph := strings.TrimSpace(par.String()) paragraph := strings.TrimSpace(par.String())
@@ -104,217 +101,82 @@ func (r *RAG) LoadRAG(fpath string) error {
paragraphs = append(paragraphs, paragraph) paragraphs = append(paragraphs, paragraph)
} }
} }
// Adjust batch size if needed // Adjust batch size if needed
if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 { if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 {
r.cfg.RAGBatchSize = len(paragraphs) r.cfg.RAGBatchSize = len(paragraphs)
} }
if len(paragraphs) == 0 { if len(paragraphs) == 0 {
return errors.New("no valid paragraphs found in file") return errors.New("no valid paragraphs found in file")
} }
// Process paragraphs in batches synchronously
var ( batchCount := 0
maxChSize = 100 for i := 0; i < len(paragraphs); i += r.cfg.RAGBatchSize {
left = 0 end := i + r.cfg.RAGBatchSize
right = r.cfg.RAGBatchSize if end > len(paragraphs) {
batchCh = make(chan map[int][]string, maxChSize) end = len(paragraphs)
vectorCh = make(chan []models.VectorRow, maxChSize)
errCh = make(chan error, 1)
doneCh = make(chan struct{})
wg = new(sync.WaitGroup)
)
defer close(doneCh)
defer close(errCh)
defer close(batchCh)
// Fill input channel with batches
ctn := 0
totalParagraphs := len(paragraphs)
for {
if right > totalParagraphs {
batchCh <- map[int][]string{left: paragraphs[left:]}
break
} }
batchCh <- map[int][]string{left: paragraphs[left:right]} batch := paragraphs[i:end]
left, right = right, right+r.cfg.RAGBatchSize batchCount++
ctn++ // Filter empty paragraphs
} nonEmptyBatch := make([]string, 0, len(batch))
for _, p := range batch {
finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents)) if strings.TrimSpace(p) != "" {
r.logger.Debug(finishedBatchesMsg) nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
select { }
case LongJobStatusCh <- finishedBatchesMsg: }
default: if len(nonEmptyBatch) == 0 {
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", finishedBatchesMsg) continue
// Channel is full or closed, ignore the message to prevent panic }
} // Embed the batch
embeddings, err := r.embedder.EmbedSlice(nonEmptyBatch)
// Start worker goroutines with WaitGroup
wg.Add(int(r.cfg.RAGWorkers))
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
go func(workerID int) {
defer wg.Done()
r.batchToVectorAsync(workerID, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
}(w)
}
// Close batchCh to signal workers no more data is coming
close(batchCh)
// Wait for all workers to finish, then close vectorCh
go func() {
wg.Wait()
close(vectorCh)
}()
// Check for errors from workers - this will block until an error occurs or all workers finish
select {
case err := <-errCh:
if err != nil { if err != nil {
r.logger.Error("error during RAG processing", "error", err) r.logger.Error("failed to embed batch", "error", err, "batch", batchCount)
select {
case LongJobStatusCh <- ErrRAGStatus:
default:
r.logger.Warn("LongJobStatusCh channel full, dropping message")
}
return fmt.Errorf("failed to embed batch %d: %w", batchCount, err)
}
if len(embeddings) != len(nonEmptyBatch) {
err := errors.New("embedding count mismatch")
r.logger.Error("embedding mismatch", "expected", len(nonEmptyBatch), "got", len(embeddings))
return err return err
} }
// Write vectors to storage
filename := path.Base(fpath)
for j, text := range nonEmptyBatch {
vector := models.VectorRow{
Embeddings: embeddings[j],
RawText: text,
Slug: fmt.Sprintf("%s_%d_%d", filename, batchCount, j),
FileName: filename,
}
if err := r.storage.WriteVector(&vector); err != nil {
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
select {
case LongJobStatusCh <- ErrRAGStatus:
default:
r.logger.Warn("LongJobStatusCh channel full, dropping message")
}
return fmt.Errorf("failed to write vector: %w", err)
}
}
r.logger.Debug("wrote batch to db", "batch", batchCount, "size", len(nonEmptyBatch))
// Send progress status
statusMsg := fmt.Sprintf("processed batch %d/%d", batchCount, (len(paragraphs)+r.cfg.RAGBatchSize-1)/r.cfg.RAGBatchSize)
select {
case LongJobStatusCh <- statusMsg:
default:
r.logger.Warn("LongJobStatusCh channel full, dropping message")
}
}
r.logger.Debug("finished writing vectors", "batches", batchCount)
select {
case LongJobStatusCh <- FinishedRAGStatus:
default: default:
// No immediate error, continue r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
} }
// Write vectors to storage - this will block until vectorCh is closed
return r.writeVectors(vectorCh, errCh)
}
func (r *RAG) writeVectors(vectorCh chan []models.VectorRow, errCh chan error) error {
// Use a select to handle both vectorCh and errCh
for {
select {
case err := <-errCh:
if err != nil {
r.logger.Error("error during RAG processing in writeVectors", "error", err)
return err
}
case batch, ok := <-vectorCh:
if !ok {
r.logger.Debug("vector channel closed, finished writing vectors")
select {
case LongJobStatusCh <- FinishedRAGStatus:
default:
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
}
return nil
}
for _, vector := range batch {
if err := r.storage.WriteVector(&vector); err != nil {
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
select {
case LongJobStatusCh <- ErrRAGStatus:
default:
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
}
return err
}
}
r.logger.Debug("wrote batch to db", "size", len(batch))
}
}
}
func (r *RAG) batchToVectorAsync(id int, inputCh <-chan map[int][]string,
vectorCh chan<- []models.VectorRow, errCh chan error, doneCh <-chan struct{}, filename string) {
var err error
defer func() {
if err != nil {
select {
case errCh <- err:
default:
r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
}
}
}()
for {
select {
case <-doneCh:
r.logger.Debug("worker received done signal", "worker", id)
return
case linesMap, ok := <-inputCh:
if !ok {
r.logger.Debug("input channel closed, worker exiting", "worker", id)
return
}
for leftI, lines := range linesMap {
select {
case <-doneCh:
return
default:
}
if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil {
r.logger.Error("error fetching embeddings", "error", err, "worker", id)
return
}
}
r.logger.Debug("processed batch", "worker#", id)
statusMsg := fmt.Sprintf("converted to vector; worker#: %d", id)
select {
case LongJobStatusCh <- statusMsg:
default:
r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
}
}
}
}
func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
// Filter out empty lines before sending to embedder
nonEmptyLines := make([]string, 0, len(lines))
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed != "" {
nonEmptyLines = append(nonEmptyLines, trimmed)
}
}
// Skip if no non-empty lines
if len(nonEmptyLines) == 0 {
// Send empty result but don't error
vectorCh <- []models.VectorRow{}
return nil
}
embeddings, err := r.embedder.EmbedSlice(nonEmptyLines)
if err != nil {
r.logger.Error("failed to embed lines", "err", err.Error())
errCh <- err
return err
}
if len(embeddings) == 0 {
err := errors.New("no embeddings returned")
r.logger.Error("empty embeddings")
errCh <- err
return err
}
if len(embeddings) != len(nonEmptyLines) {
err := errors.New("mismatch between number of lines and embeddings returned")
r.logger.Error("embedding mismatch", "err", err.Error())
errCh <- err
return err
}
// Create a VectorRow for each line in the batch
vectors := make([]models.VectorRow, len(nonEmptyLines))
for i, line := range nonEmptyLines {
vectors[i] = models.VectorRow{
Embeddings: embeddings[i],
RawText: line,
Slug: fmt.Sprintf("%s_%d", slug, i),
FileName: filename,
}
}
vectorCh <- vectors
return nil return nil
} }

View File

@@ -387,8 +387,7 @@ func makeRAGTable(fileList []string) *tview.Flex {
if err := ragger.LoadRAG(fpath); err != nil { if err := ragger.LoadRAG(fpath); err != nil {
logger.Error("failed to embed file", "chat", fpath, "error", err) logger.Error("failed to embed file", "chat", fpath, "error", err)
_ = notifyUser("RAG", "failed to embed file; error: "+err.Error()) _ = notifyUser("RAG", "failed to embed file; error: "+err.Error())
errCh <- err pages.RemovePage(RAGPage)
// pages.RemovePage(RAGPage)
return return
} }
}() }()