Fix: ragflow

This commit is contained in:
Grail Finder
2026-02-24 14:24:29 +03:00
parent 343366b12d
commit 34cd4ac141

View File

@@ -23,7 +23,6 @@ var (
ErrRAGStatus = "some error occurred; failed to transfer data to vector db" ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
) )
type RAG struct { type RAG struct {
logger *slog.Logger logger *slog.Logger
store storage.FullRepo store storage.FullRepo
@@ -122,10 +121,11 @@ func (r *RAG) LoadRAG(fpath string) error {
batchCh = make(chan map[int][]string, maxChSize) batchCh = make(chan map[int][]string, maxChSize)
vectorCh = make(chan []models.VectorRow, maxChSize) vectorCh = make(chan []models.VectorRow, maxChSize)
errCh = make(chan error, 1) errCh = make(chan error, 1)
doneCh = make(chan struct{})
wg = new(sync.WaitGroup) wg = new(sync.WaitGroup)
lock = new(sync.Mutex)
) )
defer close(doneCh)
defer close(errCh) defer close(errCh)
defer close(batchCh) defer close(batchCh)
@@ -156,18 +156,20 @@ func (r *RAG) LoadRAG(fpath string) error {
for w := 0; w < int(r.cfg.RAGWorkers); w++ { for w := 0; w < int(r.cfg.RAGWorkers); w++ {
go func(workerID int) { go func(workerID int) {
defer wg.Done() defer wg.Done()
r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath)) r.batchToVectorAsync(workerID, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
}(w) }(w)
} }
// Use a goroutine to close the batchCh when all batches are sent // Close batchCh to signal workers no more data is coming
close(batchCh)
// Wait for all workers to finish, then close vectorCh
go func() { go func() {
wg.Wait() wg.Wait()
close(vectorCh) // Close vectorCh when all workers are done close(vectorCh)
}() }()
// Check for errors from workers // Check for errors from workers - this will block until an error occurs or all workers finish
// Use a non-blocking check for errors
select { select {
case err := <-errCh: case err := <-errCh:
if err != nil { if err != nil {
@@ -179,12 +181,28 @@ func (r *RAG) LoadRAG(fpath string) error {
} }
// Write vectors to storage - this will block until vectorCh is closed // Write vectors to storage - this will block until vectorCh is closed
return r.writeVectors(vectorCh) return r.writeVectors(vectorCh, errCh)
} }
func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { func (r *RAG) writeVectors(vectorCh chan []models.VectorRow, errCh chan error) error {
// Use a select to handle both vectorCh and errCh
for { for {
for batch := range vectorCh { select {
case err := <-errCh:
if err != nil {
r.logger.Error("error during RAG processing in writeVectors", "error", err)
return err
}
case batch, ok := <-vectorCh:
if !ok {
r.logger.Debug("vector channel closed, finished writing vectors")
select {
case LongJobStatusCh <- FinishedRAGStatus:
default:
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
}
return nil
}
for _, vector := range batch { for _, vector := range batch {
if err := r.storage.WriteVector(&vector); err != nil { if err := r.storage.WriteVector(&vector); err != nil {
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug) r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
@@ -192,74 +210,57 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
case LongJobStatusCh <- ErrRAGStatus: case LongJobStatusCh <- ErrRAGStatus:
default: default:
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus) r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
// Channel is full or closed, ignore the message to prevent panic
} }
return err // Stop the entire RAG operation on DB error return err
} }
} }
r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) r.logger.Debug("wrote batch to db", "size", len(batch))
if len(vectorCh) == 0 {
r.logger.Debug("finished writing vectors")
select {
case LongJobStatusCh <- FinishedRAGStatus:
default:
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
// Channel is full or closed, ignore the message to prevent panic
}
return nil
}
} }
} }
} }
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, func (r *RAG) batchToVectorAsync(id int, inputCh <-chan map[int][]string,
vectorCh chan<- []models.VectorRow, errCh chan error, filename string) { vectorCh chan<- []models.VectorRow, errCh chan error, doneCh <-chan struct{}, filename string) {
var err error var err error
defer func() { defer func() {
// For errCh, make sure we only send if there's actually an error and the channel can accept it
if err != nil { if err != nil {
select { select {
case errCh <- err: case errCh <- err:
default: default:
// errCh might be full or closed, log but don't panic
r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err) r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
} }
} }
}() }()
for { for {
lock.Lock()
if len(inputCh) == 0 {
lock.Unlock()
return
}
select { select {
case linesMap := <-inputCh: case <-doneCh:
for leftI, lines := range linesMap { r.logger.Debug("worker received done signal", "worker", id)
if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil { return
r.logger.Error("error fetching embeddings", "error", err, "worker", id) case linesMap, ok := <-inputCh:
lock.Unlock() if !ok {
r.logger.Debug("input channel closed, worker exiting", "worker", id)
return return
} }
} for leftI, lines := range linesMap {
lock.Unlock() select {
case err = <-errCh: case <-doneCh:
r.logger.Error("got an error from error channel", "error", err)
lock.Unlock()
return return
default: default:
lock.Unlock()
} }
if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil {
r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id) r.logger.Error("error fetching embeddings", "error", err, "worker", id)
statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) return
}
}
r.logger.Debug("processed batch", "worker#", id)
statusMsg := fmt.Sprintf("converted to vector; worker#: %d", id)
select { select {
case LongJobStatusCh <- statusMsg: case LongJobStatusCh <- statusMsg:
default: default:
r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg) r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
// Channel is full or closed, ignore the message to prevent panic }
} }
} }
} }