Fix: ragflow
This commit is contained in:
97
rag/rag.go
97
rag/rag.go
@@ -23,7 +23,6 @@ var (
|
|||||||
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
type RAG struct {
|
type RAG struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
store storage.FullRepo
|
store storage.FullRepo
|
||||||
@@ -122,10 +121,11 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
batchCh = make(chan map[int][]string, maxChSize)
|
batchCh = make(chan map[int][]string, maxChSize)
|
||||||
vectorCh = make(chan []models.VectorRow, maxChSize)
|
vectorCh = make(chan []models.VectorRow, maxChSize)
|
||||||
errCh = make(chan error, 1)
|
errCh = make(chan error, 1)
|
||||||
|
doneCh = make(chan struct{})
|
||||||
wg = new(sync.WaitGroup)
|
wg = new(sync.WaitGroup)
|
||||||
lock = new(sync.Mutex)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
defer close(doneCh)
|
||||||
defer close(errCh)
|
defer close(errCh)
|
||||||
defer close(batchCh)
|
defer close(batchCh)
|
||||||
|
|
||||||
@@ -156,18 +156,20 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
|
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
|
||||||
go func(workerID int) {
|
go func(workerID int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath))
|
r.batchToVectorAsync(workerID, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
|
||||||
}(w)
|
}(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a goroutine to close the batchCh when all batches are sent
|
// Close batchCh to signal workers no more data is coming
|
||||||
|
close(batchCh)
|
||||||
|
|
||||||
|
// Wait for all workers to finish, then close vectorCh
|
||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
close(vectorCh) // Close vectorCh when all workers are done
|
close(vectorCh)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Check for errors from workers
|
// Check for errors from workers - this will block until an error occurs or all workers finish
|
||||||
// Use a non-blocking check for errors
|
|
||||||
select {
|
select {
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -179,12 +181,28 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write vectors to storage - this will block until vectorCh is closed
|
// Write vectors to storage - this will block until vectorCh is closed
|
||||||
return r.writeVectors(vectorCh)
|
return r.writeVectors(vectorCh, errCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
func (r *RAG) writeVectors(vectorCh chan []models.VectorRow, errCh chan error) error {
|
||||||
|
// Use a select to handle both vectorCh and errCh
|
||||||
for {
|
for {
|
||||||
for batch := range vectorCh {
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Error("error during RAG processing in writeVectors", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case batch, ok := <-vectorCh:
|
||||||
|
if !ok {
|
||||||
|
r.logger.Debug("vector channel closed, finished writing vectors")
|
||||||
|
select {
|
||||||
|
case LongJobStatusCh <- FinishedRAGStatus:
|
||||||
|
default:
|
||||||
|
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
for _, vector := range batch {
|
for _, vector := range batch {
|
||||||
if err := r.storage.WriteVector(&vector); err != nil {
|
if err := r.storage.WriteVector(&vector); err != nil {
|
||||||
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
|
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
|
||||||
@@ -192,74 +210,57 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
|||||||
case LongJobStatusCh <- ErrRAGStatus:
|
case LongJobStatusCh <- ErrRAGStatus:
|
||||||
default:
|
default:
|
||||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
|
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
|
||||||
// Channel is full or closed, ignore the message to prevent panic
|
|
||||||
}
|
}
|
||||||
return err // Stop the entire RAG operation on DB error
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
|
r.logger.Debug("wrote batch to db", "size", len(batch))
|
||||||
if len(vectorCh) == 0 {
|
|
||||||
r.logger.Debug("finished writing vectors")
|
|
||||||
select {
|
|
||||||
case LongJobStatusCh <- FinishedRAGStatus:
|
|
||||||
default:
|
|
||||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
|
|
||||||
// Channel is full or closed, ignore the message to prevent panic
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
|
func (r *RAG) batchToVectorAsync(id int, inputCh <-chan map[int][]string,
|
||||||
vectorCh chan<- []models.VectorRow, errCh chan error, filename string) {
|
vectorCh chan<- []models.VectorRow, errCh chan error, doneCh <-chan struct{}, filename string) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// For errCh, make sure we only send if there's actually an error and the channel can accept it
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
select {
|
select {
|
||||||
case errCh <- err:
|
case errCh <- err:
|
||||||
default:
|
default:
|
||||||
// errCh might be full or closed, log but don't panic
|
|
||||||
r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
|
r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
lock.Lock()
|
|
||||||
if len(inputCh) == 0 {
|
|
||||||
lock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case linesMap := <-inputCh:
|
case <-doneCh:
|
||||||
for leftI, lines := range linesMap {
|
r.logger.Debug("worker received done signal", "worker", id)
|
||||||
if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil {
|
return
|
||||||
r.logger.Error("error fetching embeddings", "error", err, "worker", id)
|
case linesMap, ok := <-inputCh:
|
||||||
lock.Unlock()
|
if !ok {
|
||||||
|
r.logger.Debug("input channel closed, worker exiting", "worker", id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
for leftI, lines := range linesMap {
|
||||||
lock.Unlock()
|
select {
|
||||||
case err = <-errCh:
|
case <-doneCh:
|
||||||
r.logger.Error("got an error from error channel", "error", err)
|
|
||||||
lock.Unlock()
|
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
lock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil {
|
||||||
r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id)
|
r.logger.Error("error fetching embeddings", "error", err, "worker", id)
|
||||||
statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.logger.Debug("processed batch", "worker#", id)
|
||||||
|
statusMsg := fmt.Sprintf("converted to vector; worker#: %d", id)
|
||||||
select {
|
select {
|
||||||
case LongJobStatusCh <- statusMsg:
|
case LongJobStatusCh <- statusMsg:
|
||||||
default:
|
default:
|
||||||
r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
|
r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
|
||||||
// Channel is full or closed, ignore the message to prevent panic
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user