diff --git a/rag/rag.go b/rag/rag.go index d8b6978..71c4ce8 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -23,7 +23,6 @@ var ( ErrRAGStatus = "some error occurred; failed to transfer data to vector db" ) - type RAG struct { logger *slog.Logger store storage.FullRepo @@ -122,10 +121,11 @@ func (r *RAG) LoadRAG(fpath string) error { batchCh = make(chan map[int][]string, maxChSize) vectorCh = make(chan []models.VectorRow, maxChSize) errCh = make(chan error, 1) + doneCh = make(chan struct{}) wg = new(sync.WaitGroup) - lock = new(sync.Mutex) ) + defer close(doneCh) defer close(errCh) defer close(batchCh) @@ -156,18 +156,20 @@ func (r *RAG) LoadRAG(fpath string) error { for w := 0; w < int(r.cfg.RAGWorkers); w++ { go func(workerID int) { defer wg.Done() - r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath)) + r.batchToVectorAsync(workerID, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) }(w) } - // Use a goroutine to close the batchCh when all batches are sent + // Close batchCh to signal workers no more data is coming + close(batchCh) + + // Wait for all workers to finish, then close vectorCh go func() { wg.Wait() - close(vectorCh) // Close vectorCh when all workers are done + close(vectorCh) }() - // Check for errors from workers - // Use a non-blocking check for errors + // Check for errors from workers - this will block until an error occurs or all workers finish select { case err := <-errCh: if err != nil { @@ -179,12 +181,28 @@ func (r *RAG) LoadRAG(fpath string) error { } // Write vectors to storage - this will block until vectorCh is closed - return r.writeVectors(vectorCh) + return r.writeVectors(vectorCh, errCh) } -func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { +func (r *RAG) writeVectors(vectorCh chan []models.VectorRow, errCh chan error) error { + // Use a select to handle both vectorCh and errCh for { - for batch := range vectorCh { + select { + case err := <-errCh: + if err != nil { + r.logger.Error("error during RAG processing in writeVectors", "error", err) + return err + } + case batch, ok := <-vectorCh: + if !ok { + r.logger.Debug("vector channel closed, finished writing vectors") + select { + case LongJobStatusCh <- FinishedRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus) + } + return nil + } for _, vector := range batch { if err := r.storage.WriteVector(&vector); err != nil { r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug) @@ -192,74 +210,57 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { case LongJobStatusCh <- ErrRAGStatus: default: r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus) - // Channel is full or closed, ignore the message to prevent panic } - return err // Stop the entire RAG operation on DB error + return err } } - r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) - if len(vectorCh) == 0 { - r.logger.Debug("finished writing vectors") - select { - case LongJobStatusCh <- FinishedRAGStatus: - default: - r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus) - // Channel is full or closed, ignore the message to prevent panic - } - return nil - } + r.logger.Debug("wrote batch to db", "size", len(batch)) } } } -func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, - vectorCh chan<- []models.VectorRow, errCh chan error, filename string) { +func (r *RAG) batchToVectorAsync(id int, inputCh <-chan map[int][]string, + vectorCh chan<- []models.VectorRow, errCh chan error, doneCh <-chan struct{}, filename string) { var err error defer func() { - // For errCh, make sure we only send if there's actually an error and the channel can accept it if err != nil { select { case errCh <- err: default: - // errCh might be full or closed, log but don't panic r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err) } } }() for { - lock.Lock() - if len(inputCh) == 0 { - lock.Unlock() - return - } - select { - case linesMap := <-inputCh: + case <-doneCh: + r.logger.Debug("worker received done signal", "worker", id) + return + case linesMap, ok := <-inputCh: + if !ok { + r.logger.Debug("input channel closed, worker exiting", "worker", id) + return + } for leftI, lines := range linesMap { + select { + case <-doneCh: + return + default: + } if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil { r.logger.Error("error fetching embeddings", "error", err, "worker", id) - lock.Unlock() return } } - lock.Unlock() - case err = <-errCh: - r.logger.Error("got an error from error channel", "error", err) - lock.Unlock() - return - default: - lock.Unlock() - } - - r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id) - statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) - select { - case LongJobStatusCh <- statusMsg: - default: - r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg) - // Channel is full or closed, ignore the message to prevent panic + r.logger.Debug("processed batch", "worker#", id) + statusMsg := fmt.Sprintf("converted to vector; worker#: %d", id) + select { + case LongJobStatusCh <- statusMsg: + default: + r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg) + } } } }