Fix: rag panics
bot.go (2 changes)
@@ -429,13 +429,11 @@ func chatRagUse(qText string) (string, error) {
 			logger.Error("failed to get embs", "error", err, "index", i, "question", q)
 			continue
 		}
 
 		// Create EmbeddingResp struct for the search
 		embeddingResp := &models.EmbeddingResp{
 			Embedding: emb,
 			Index:     0, // Not used in search but required for the struct
 		}
 
 		vecs, err := ragger.SearchEmb(embeddingResp)
 		if err != nil {
 			logger.Error("failed to query embs", "error", err, "index", i, "question", q)
@@ -10,7 +10,7 @@ DeepSeekModel = "deepseek-reasoner"
 OpenRouterCompletionAPI = "https://openrouter.ai/api/v1/completions"
 OpenRouterChatAPI = "https://openrouter.ai/api/v1/chat/completions"
 # OpenRouterToken = ""
-EmbedURL = "http://localhost:8080/v1/embeddings"
+EmbedURL = "http://localhost:8082/v1/embeddings"
 ShowSys = true
 LogFile = "log.txt"
 UserRole = "user"
@@ -19,7 +19,7 @@ AssistantRole = "assistant"
 SysDir = "sysprompts"
 ChunkLimit = 100000
 # rag settings
-RAGBatchSize = 10
+RAGBatchSize = 1
 RAGWordLimit = 80
 RAGWorkers = 2
 RAGDir = "ragimport"
rag/rag.go (87 changes)
@@ -23,6 +23,7 @@ var (
 	ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
 )
 
+
 type RAG struct {
 	logger *slog.Logger
 	store  storage.FullRepo
@@ -58,7 +59,12 @@ func (r *RAG) LoadRAG(fpath string) error {
 		return err
 	}
 	r.logger.Debug("rag: loaded file", "fp", fpath)
-	LongJobStatusCh <- LoadedFileRAGStatus
+	select {
+	case LongJobStatusCh <- LoadedFileRAGStatus:
+	default:
+		r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
+		// Channel is full or closed, ignore the message to prevent panic
+	}
 
 	fileText := string(data)
 	tokenizer, err := english.NewSentenceTokenizer(nil)
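Every status send in this commit is rewritten into this same select-with-default shape, so the pattern is worth seeing in isolation. Below is a minimal, self-contained Go sketch; the names are illustrative, not from this repo. One caveat the in-diff comments gloss over: default only protects against a full or unread channel, while a send on a channel that has already been closed still panics, so the fix implicitly assumes LongJobStatusCh is never closed while producers are running.

package main

import "log/slog"

// trySend performs a non-blocking send: if statusCh cannot accept the
// message right now, it is dropped with a warning instead of wedging
// the sender.
func trySend(statusCh chan string, msg string) {
	select {
	case statusCh <- msg:
	default:
		slog.Warn("status channel full, dropping message", "message", msg)
	}
}

func main() {
	statusCh := make(chan string, 1)
	trySend(statusCh, "first")  // fits in the buffer
	trySend(statusCh, "second") // buffer full: dropped, not blocked
	slog.Info("drained", "msg", <-statusCh)
}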
@@ -116,11 +122,10 @@ func (r *RAG) LoadRAG(fpath string) error {
 		batchCh  = make(chan map[int][]string, maxChSize)
 		vectorCh = make(chan []models.VectorRow, maxChSize)
 		errCh    = make(chan error, 1)
-		doneCh   = make(chan bool, 1)
+		wg       = new(sync.WaitGroup)
 		lock     = new(sync.Mutex)
 	)
 
-	defer close(doneCh)
 	defer close(errCh)
 	defer close(batchCh)
 
@@ -139,19 +144,41 @@ func (r *RAG) LoadRAG(fpath string) error {
 
 	finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents))
 	r.logger.Debug(finishedBatchesMsg)
-	LongJobStatusCh <- finishedBatchesMsg
+	select {
+	case LongJobStatusCh <- finishedBatchesMsg:
+	default:
+		r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", finishedBatchesMsg)
+		// Channel is full or closed, ignore the message to prevent panic
+	}
 
-	// Start worker goroutines
+	// Start worker goroutines with WaitGroup
+	wg.Add(int(r.cfg.RAGWorkers))
 	for w := 0; w < int(r.cfg.RAGWorkers); w++ {
-		go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
+		go func(workerID int) {
+			defer wg.Done()
+			r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath))
+		}(w)
 	}
-	// Wait for embedding to be done
-	<-doneCh
-	err = <-errCh
-	if err != nil {
-		return err
+
+	// Close vectorCh once all workers are done, so writeVectors can finish
+	go func() {
+		wg.Wait()
+		close(vectorCh)
+	}()
+
+	// Check for errors from workers with a non-blocking receive
+	select {
+	case err := <-errCh:
+		if err != nil {
+			r.logger.Error("error during RAG processing", "error", err)
+			return err
+		}
+	default:
+		// No immediate error, continue
 	}
-	// Write vectors to storage
+	// Write vectors to storage - this will block until vectorCh is closed
 	return r.writeVectors(vectorCh)
 }
 
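The hunk above swaps the doneCh handshake for a sync.WaitGroup plus a closer goroutine: workers signal completion by returning, and the closure of vectorCh becomes the consumer's termination signal. A runnable sketch of that shape, with invented names and a two-worker pool standing in for RAGWorkers:

package main

import (
	"fmt"
	"sync"
)

func main() {
	jobs := make(chan int, 4)
	results := make(chan int, 4)
	var wg sync.WaitGroup

	// Fixed worker pool, analogous to RAGWorkers.
	for w := 0; w < 2; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := range jobs {
				results <- j * j
			}
		}()
	}

	// Close results only after every worker has returned; the consumer's
	// range loop below then terminates naturally instead of deadlocking.
	go func() {
		wg.Wait()
		close(results)
	}()

	for i := 1; i <= 4; i++ {
		jobs <- i
	}
	close(jobs) // workers' range loops end once the buffer drains

	for r := range results {
		fmt.Println(r)
	}
}

Closing the results channel from a dedicated goroutine, rather than from any one worker, matters because close may only happen once; the WaitGroup guarantees no worker can still be sending when it fires.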
@@ -161,14 +188,24 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
 		for _, vector := range batch {
 			if err := r.storage.WriteVector(&vector); err != nil {
 				r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
-				LongJobStatusCh <- ErrRAGStatus
+				select {
+				case LongJobStatusCh <- ErrRAGStatus:
+				default:
+					r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
+					// Channel is full or closed, ignore the message to prevent panic
+				}
 				return err // Stop the entire RAG operation on DB error
 			}
 		}
 		r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
 		if len(vectorCh) == 0 {
 			r.logger.Debug("finished writing vectors")
-			LongJobStatusCh <- FinishedRAGStatus
+			select {
+			case LongJobStatusCh <- FinishedRAGStatus:
+			default:
+				r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
+				// Channel is full or closed, ignore the message to prevent panic
+			}
 			return nil
 		}
 	}
@@ -176,12 +213,18 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
 }
 
 func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
-	vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
+	vectorCh chan<- []models.VectorRow, errCh chan error, filename string) {
 	var err error
 
 	defer func() {
-		if len(doneCh) == 0 {
-			doneCh <- true
-			errCh <- err
+		// For errCh, make sure we only send if there's actually an error and the channel can accept it
+		if err != nil {
+			select {
+			case errCh <- err:
+			default:
+				// errCh might be full or closed, log but don't panic
+				r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
+			}
 		}
 	}()
 
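With doneCh removed from the worker signature, the deferred function above becomes the sole error-propagation path. A compact runnable sketch of just that defer logic (the worker body and names are invented for the demo); it leans on errCh having a one-slot buffer, matching the errCh = make(chan error, 1) declaration in LoadRAG:

package main

import (
	"fmt"
	"sync"
)

// run mimics the reworked defer in batchToVectorAsync: an error is
// forwarded only if one occurred, and only if errCh can take it right
// now; otherwise it is dropped rather than blocking the worker forever.
func run(id int, errCh chan error, wg *sync.WaitGroup) {
	defer wg.Done()
	var err error
	defer func() {
		if err != nil {
			select {
			case errCh <- err:
			default: // buffer already holds an earlier error; drop this one
			}
		}
	}()
	err = fmt.Errorf("worker %d failed", id)
}

func main() {
	errCh := make(chan error, 1) // capacity 1, like errCh in LoadRAG
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go run(i, errCh, &wg)
	}
	wg.Wait()
	fmt.Println(<-errCh) // whichever worker's error won the send
}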
@@ -211,7 +254,13 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
 		}
 
 		r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id)
-		LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
+		statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
+		select {
+		case LongJobStatusCh <- statusMsg:
+		default:
+			r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
+			// Channel is full or closed, ignore the message to prevent panic
+		}
 	}
 }
 