Fix: ragflow
This commit is contained in:
103
rag/rag.go
103
rag/rag.go
@@ -23,7 +23,6 @@ var (
|
||||
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
||||
)
|
||||
|
||||
|
||||
type RAG struct {
|
||||
logger *slog.Logger
|
||||
store storage.FullRepo
|
||||
@@ -122,10 +121,11 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
batchCh = make(chan map[int][]string, maxChSize)
|
||||
vectorCh = make(chan []models.VectorRow, maxChSize)
|
||||
errCh = make(chan error, 1)
|
||||
doneCh = make(chan struct{})
|
||||
wg = new(sync.WaitGroup)
|
||||
lock = new(sync.Mutex)
|
||||
)
|
||||
|
||||
defer close(doneCh)
|
||||
defer close(errCh)
|
||||
defer close(batchCh)
|
||||
|
||||
@@ -156,18 +156,20 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath))
|
||||
r.batchToVectorAsync(workerID, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
|
||||
}(w)
|
||||
}
|
||||
|
||||
// Use a goroutine to close the batchCh when all batches are sent
|
||||
// Close batchCh to signal workers no more data is coming
|
||||
close(batchCh)
|
||||
|
||||
// Wait for all workers to finish, then close vectorCh
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(vectorCh) // Close vectorCh when all workers are done
|
||||
close(vectorCh)
|
||||
}()
|
||||
|
||||
// Check for errors from workers
|
||||
// Use a non-blocking check for errors
|
||||
// Check for errors from workers - this will block until an error occurs or all workers finish
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
@@ -179,12 +181,28 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
}
|
||||
|
||||
// Write vectors to storage - this will block until vectorCh is closed
|
||||
return r.writeVectors(vectorCh)
|
||||
return r.writeVectors(vectorCh, errCh)
|
||||
}
|
||||
|
||||
func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
||||
func (r *RAG) writeVectors(vectorCh chan []models.VectorRow, errCh chan error) error {
|
||||
// Use a select to handle both vectorCh and errCh
|
||||
for {
|
||||
for batch := range vectorCh {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
r.logger.Error("error during RAG processing in writeVectors", "error", err)
|
||||
return err
|
||||
}
|
||||
case batch, ok := <-vectorCh:
|
||||
if !ok {
|
||||
r.logger.Debug("vector channel closed, finished writing vectors")
|
||||
select {
|
||||
case LongJobStatusCh <- FinishedRAGStatus:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
for _, vector := range batch {
|
||||
if err := r.storage.WriteVector(&vector); err != nil {
|
||||
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
|
||||
@@ -192,74 +210,57 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
||||
case LongJobStatusCh <- ErrRAGStatus:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
|
||||
// Channel is full or closed, ignore the message to prevent panic
|
||||
}
|
||||
return err // Stop the entire RAG operation on DB error
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
|
||||
if len(vectorCh) == 0 {
|
||||
r.logger.Debug("finished writing vectors")
|
||||
select {
|
||||
case LongJobStatusCh <- FinishedRAGStatus:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
|
||||
// Channel is full or closed, ignore the message to prevent panic
|
||||
}
|
||||
return nil
|
||||
}
|
||||
r.logger.Debug("wrote batch to db", "size", len(batch))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
|
||||
vectorCh chan<- []models.VectorRow, errCh chan error, filename string) {
|
||||
func (r *RAG) batchToVectorAsync(id int, inputCh <-chan map[int][]string,
|
||||
vectorCh chan<- []models.VectorRow, errCh chan error, doneCh <-chan struct{}, filename string) {
|
||||
var err error
|
||||
|
||||
defer func() {
|
||||
// For errCh, make sure we only send if there's actually an error and the channel can accept it
|
||||
if err != nil {
|
||||
select {
|
||||
case errCh <- err:
|
||||
default:
|
||||
// errCh might be full or closed, log but don't panic
|
||||
r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
lock.Lock()
|
||||
if len(inputCh) == 0 {
|
||||
lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case linesMap := <-inputCh:
|
||||
case <-doneCh:
|
||||
r.logger.Debug("worker received done signal", "worker", id)
|
||||
return
|
||||
case linesMap, ok := <-inputCh:
|
||||
if !ok {
|
||||
r.logger.Debug("input channel closed, worker exiting", "worker", id)
|
||||
return
|
||||
}
|
||||
for leftI, lines := range linesMap {
|
||||
select {
|
||||
case <-doneCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil {
|
||||
r.logger.Error("error fetching embeddings", "error", err, "worker", id)
|
||||
lock.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
lock.Unlock()
|
||||
case err = <-errCh:
|
||||
r.logger.Error("got an error from error channel", "error", err)
|
||||
lock.Unlock()
|
||||
return
|
||||
default:
|
||||
lock.Unlock()
|
||||
}
|
||||
|
||||
r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id)
|
||||
statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
|
||||
select {
|
||||
case LongJobStatusCh <- statusMsg:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
|
||||
// Channel is full or closed, ignore the message to prevent panic
|
||||
r.logger.Debug("processed batch", "worker#", id)
|
||||
statusMsg := fmt.Sprintf("converted to vector; worker#: %d", id)
|
||||
select {
|
||||
case LongJobStatusCh <- statusMsg:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user