Enha (rag): async writes
This commit is contained in:
446
rag/rag.go
446
rag/rag.go
@@ -1,6 +1,7 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gf-lt/config"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"log/slog"
|
||||
"path"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -17,9 +19,14 @@ import (
|
||||
"github.com/neurosnap/sentences/english"
|
||||
)
|
||||
|
||||
const (
|
||||
// batchTimeout is the maximum time allowed for embedding a single batch
|
||||
batchTimeout = 2 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
// Status messages for TUI integration
|
||||
LongJobStatusCh = make(chan string, 10) // Increased buffer size to prevent blocking
|
||||
LongJobStatusCh = make(chan string, 100) // Increased buffer size for parallel batch updates
|
||||
FinishedRAGStatus = "finished loading RAG file; press Enter"
|
||||
LoadedFileRAGStatus = "loaded file"
|
||||
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
||||
@@ -31,12 +38,38 @@ type RAG struct {
|
||||
cfg *config.Config
|
||||
embedder Embedder
|
||||
storage *VectorStorage
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
idleMu sync.Mutex
|
||||
fallbackMsg string
|
||||
idleTimer *time.Timer
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
// batchTask represents a single batch to be embedded
|
||||
type batchTask struct {
|
||||
batchIndex int
|
||||
paragraphs []string
|
||||
filename string
|
||||
totalBatches int
|
||||
}
|
||||
|
||||
// batchResult represents the result of embedding a batch
|
||||
type batchResult struct {
|
||||
batchIndex int
|
||||
embeddings [][]float32
|
||||
paragraphs []string
|
||||
filename string
|
||||
}
|
||||
|
||||
// sendStatusNonBlocking sends a status message without blocking
|
||||
func (r *RAG) sendStatusNonBlocking(status string) {
|
||||
select {
|
||||
case LongJobStatusCh <- status:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", status)
|
||||
}
|
||||
}
|
||||
|
||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
||||
var embedder Embedder
|
||||
var fallbackMsg string
|
||||
@@ -142,18 +175,22 @@ func sanitizeFTSQuery(query string) string {
|
||||
}
|
||||
|
||||
func (r *RAG) LoadRAG(fpath string) error {
|
||||
return r.LoadRAGWithContext(context.Background(), fpath)
|
||||
}
|
||||
|
||||
func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
fileText, err := ExtractText(fpath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.logger.Debug("rag: loaded file", "fp", fpath)
|
||||
select {
|
||||
case LongJobStatusCh <- LoadedFileRAGStatus:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
|
||||
}
|
||||
|
||||
// Send initial status (non-blocking with retry)
|
||||
r.sendStatusNonBlocking(LoadedFileRAGStatus)
|
||||
|
||||
tokenizer, err := english.NewSentenceTokenizer(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -163,6 +200,7 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
for i, s := range sentences {
|
||||
sents[i] = s.Text
|
||||
}
|
||||
|
||||
// Create chunks with overlap
|
||||
paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords)
|
||||
// Adjust batch size if needed
|
||||
@@ -172,76 +210,332 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
if len(paragraphs) == 0 {
|
||||
return errors.New("no valid paragraphs found in file")
|
||||
}
|
||||
// Process paragraphs in batches synchronously
|
||||
batchCount := 0
|
||||
for i := 0; i < len(paragraphs); i += r.cfg.RAGBatchSize {
|
||||
end := i + r.cfg.RAGBatchSize
|
||||
if end > len(paragraphs) {
|
||||
end = len(paragraphs)
|
||||
}
|
||||
batch := paragraphs[i:end]
|
||||
batchCount++
|
||||
// Filter empty paragraphs
|
||||
nonEmptyBatch := make([]string, 0, len(batch))
|
||||
for _, p := range batch {
|
||||
if strings.TrimSpace(p) != "" {
|
||||
nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
|
||||
|
||||
totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize
|
||||
r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize)
|
||||
|
||||
// Determine concurrency level
|
||||
concurrency := runtime.NumCPU()
|
||||
if concurrency > totalBatches {
|
||||
concurrency = totalBatches
|
||||
}
|
||||
if concurrency < 1 {
|
||||
concurrency = 1
|
||||
}
|
||||
// If using ONNX embedder, limit concurrency to 1 due to mutex serialization
|
||||
isONNX := false
|
||||
if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX {
|
||||
concurrency = 1
|
||||
}
|
||||
embedderType := "API"
|
||||
if isONNX {
|
||||
embedderType = "ONNX"
|
||||
}
|
||||
r.logger.Debug("parallel embedding setup",
|
||||
"total_batches", totalBatches,
|
||||
"concurrency", concurrency,
|
||||
"embedder", embedderType,
|
||||
"batch_size", r.cfg.RAGBatchSize)
|
||||
|
||||
// Create context with timeout (30 minutes) and cancellation for error handling
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Channels for task distribution and results
|
||||
taskCh := make(chan batchTask, totalBatches)
|
||||
resultCh := make(chan batchResult, totalBatches)
|
||||
errorCh := make(chan error, totalBatches)
|
||||
|
||||
// Start worker goroutines
|
||||
var wg sync.WaitGroup
|
||||
for w := 0; w < concurrency; w++ {
|
||||
wg.Add(1)
|
||||
go r.embeddingWorker(ctx, w, taskCh, resultCh, errorCh, &wg)
|
||||
}
|
||||
|
||||
// Close task channel after all tasks are sent (by separate goroutine)
|
||||
go func() {
|
||||
// Ensure task channel is closed when this goroutine exits
|
||||
defer close(taskCh)
|
||||
r.logger.Debug("task distributor started", "total_batches", totalBatches)
|
||||
|
||||
for i := 0; i < totalBatches; i++ {
|
||||
start := i * r.cfg.RAGBatchSize
|
||||
end := start + r.cfg.RAGBatchSize
|
||||
if end > len(paragraphs) {
|
||||
end = len(paragraphs)
|
||||
}
|
||||
batch := paragraphs[start:end]
|
||||
|
||||
// Filter empty paragraphs
|
||||
nonEmptyBatch := make([]string, 0, len(batch))
|
||||
for _, p := range batch {
|
||||
if strings.TrimSpace(p) != "" {
|
||||
nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
|
||||
}
|
||||
}
|
||||
|
||||
task := batchTask{
|
||||
batchIndex: i,
|
||||
paragraphs: nonEmptyBatch,
|
||||
filename: path.Base(fpath),
|
||||
totalBatches: totalBatches,
|
||||
}
|
||||
|
||||
select {
|
||||
case taskCh <- task:
|
||||
r.logger.Debug("task distributor sent batch", "batch", i, "paragraphs", len(nonEmptyBatch))
|
||||
case <-ctx.Done():
|
||||
r.logger.Debug("task distributor cancelled", "batches_sent", i+1, "total_batches", totalBatches)
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(nonEmptyBatch) == 0 {
|
||||
r.logger.Debug("task distributor finished", "batches_sent", totalBatches)
|
||||
}()
|
||||
|
||||
// Wait for workers to finish and close result channel
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultCh)
|
||||
}()
|
||||
|
||||
// Process results in order and write to database
|
||||
nextExpectedBatch := 0
|
||||
resultsBuffer := make(map[int]batchResult)
|
||||
filename := path.Base(fpath)
|
||||
batchesProcessed := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
||||
case err := <-errorCh:
|
||||
// First error from any worker, cancel everything
|
||||
cancel()
|
||||
r.logger.Error("embedding worker failed", "error", err)
|
||||
r.sendStatusNonBlocking(ErrRAGStatus)
|
||||
return fmt.Errorf("embedding failed: %w", err)
|
||||
|
||||
case result, ok := <-resultCh:
|
||||
if !ok {
|
||||
// All results processed
|
||||
resultCh = nil
|
||||
r.logger.Debug("result channel closed", "batches_processed", batchesProcessed, "total_batches", totalBatches)
|
||||
continue
|
||||
}
|
||||
|
||||
// Store result in buffer
|
||||
resultsBuffer[result.batchIndex] = result
|
||||
|
||||
// Process buffered results in order
|
||||
for {
|
||||
if res, exists := resultsBuffer[nextExpectedBatch]; exists {
|
||||
// Write this batch to database
|
||||
if err := r.writeBatchToStorage(ctx, res, filename); err != nil {
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
|
||||
batchesProcessed++
|
||||
// Send progress update
|
||||
statusMsg := fmt.Sprintf("processed batch %d/%d", batchesProcessed, totalBatches)
|
||||
r.sendStatusNonBlocking(statusMsg)
|
||||
|
||||
delete(resultsBuffer, nextExpectedBatch)
|
||||
nextExpectedBatch++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// No channels ready, check for deadlock conditions
|
||||
if resultCh == nil && nextExpectedBatch < totalBatches {
|
||||
// Missing batch results after result channel closed
|
||||
r.logger.Error("missing batch results",
|
||||
"expected", totalBatches,
|
||||
"received", nextExpectedBatch,
|
||||
"missing", totalBatches-nextExpectedBatch)
|
||||
|
||||
// Wait a short time for any delayed errors, then cancel
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
cancel()
|
||||
return fmt.Errorf("missing batch results: expected %d, got %d", totalBatches, nextExpectedBatch)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errorCh:
|
||||
cancel()
|
||||
r.logger.Error("embedding worker failed after result channel closed", "error", err)
|
||||
r.sendStatusNonBlocking(ErrRAGStatus)
|
||||
return fmt.Errorf("embedding failed: %w", err)
|
||||
}
|
||||
}
|
||||
// If we reach here, no deadlock yet, just busy loop prevention
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Check if we're done
|
||||
if resultCh == nil && nextExpectedBatch >= totalBatches {
|
||||
r.logger.Debug("all batches processed successfully", "total", totalBatches)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Debug("finished writing vectors", "batches", batchesProcessed)
|
||||
r.resetIdleTimer()
|
||||
r.sendStatusNonBlocking(FinishedRAGStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
// embeddingWorker processes batch embedding tasks
|
||||
func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan batchTask, resultCh chan<- batchResult, errorCh chan<- error, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
r.logger.Debug("embedding worker started", "worker", workerID)
|
||||
|
||||
// Panic recovery to ensure worker doesn't crash silently
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
r.logger.Error("embedding worker panicked", "worker", workerID, "panic", rec)
|
||||
// Try to send error, but don't block if channel is full
|
||||
select {
|
||||
case errorCh <- fmt.Errorf("worker %d panicked: %v", workerID, rec):
|
||||
default:
|
||||
r.logger.Warn("error channel full, dropping panic error", "worker", workerID)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for task := range taskCh {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
r.logger.Debug("embedding worker cancelled", "worker", workerID)
|
||||
return
|
||||
default:
|
||||
}
|
||||
r.logger.Debug("worker processing batch", "worker", workerID, "batch", task.batchIndex, "paragraphs", len(task.paragraphs), "total_batches", task.totalBatches)
|
||||
|
||||
// Skip empty batches
|
||||
if len(task.paragraphs) == 0 {
|
||||
select {
|
||||
case resultCh <- batchResult{
|
||||
batchIndex: task.batchIndex,
|
||||
embeddings: nil,
|
||||
paragraphs: nil,
|
||||
filename: task.filename,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
r.logger.Debug("embedding worker cancelled while sending empty batch", "worker", workerID)
|
||||
return
|
||||
}
|
||||
r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex)
|
||||
continue
|
||||
}
|
||||
// Embed the batch
|
||||
embeddings, err := r.embedder.EmbedSlice(nonEmptyBatch)
|
||||
|
||||
// Embed with retry for API embedder
|
||||
embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to embed batch", "error", err, "batch", batchCount)
|
||||
// Try to send error, but don't block indefinitely
|
||||
select {
|
||||
case LongJobStatusCh <- ErrRAGStatus:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel full, dropping message")
|
||||
case errorCh <- fmt.Errorf("worker %d batch %d: %w", workerID, task.batchIndex, err):
|
||||
case <-ctx.Done():
|
||||
r.logger.Debug("embedding worker cancelled while sending error", "worker", workerID)
|
||||
}
|
||||
return fmt.Errorf("failed to embed batch %d: %w", batchCount, err)
|
||||
return
|
||||
}
|
||||
if len(embeddings) != len(nonEmptyBatch) {
|
||||
err := errors.New("embedding count mismatch")
|
||||
r.logger.Error("embedding mismatch", "expected", len(nonEmptyBatch), "got", len(embeddings))
|
||||
return err
|
||||
}
|
||||
// Write vectors to storage
|
||||
filename := path.Base(fpath)
|
||||
for j, text := range nonEmptyBatch {
|
||||
vector := models.VectorRow{
|
||||
Embeddings: embeddings[j],
|
||||
RawText: text,
|
||||
Slug: fmt.Sprintf("%s_%d_%d", filename, batchCount, j),
|
||||
FileName: filename,
|
||||
}
|
||||
if err := r.storage.WriteVector(&vector); err != nil {
|
||||
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
|
||||
select {
|
||||
case LongJobStatusCh <- ErrRAGStatus:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel full, dropping message")
|
||||
}
|
||||
return fmt.Errorf("failed to write vector: %w", err)
|
||||
}
|
||||
}
|
||||
r.logger.Debug("wrote batch to db", "batch", batchCount, "size", len(nonEmptyBatch))
|
||||
// Send progress status
|
||||
statusMsg := fmt.Sprintf("processed batch %d/%d", batchCount, (len(paragraphs)+r.cfg.RAGBatchSize-1)/r.cfg.RAGBatchSize)
|
||||
|
||||
// Send result with context awareness
|
||||
select {
|
||||
case LongJobStatusCh <- statusMsg:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel full, dropping message")
|
||||
case resultCh <- batchResult{
|
||||
batchIndex: task.batchIndex,
|
||||
embeddings: embeddings,
|
||||
paragraphs: task.paragraphs,
|
||||
filename: task.filename,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
r.logger.Debug("embedding worker cancelled while sending result", "worker", workerID)
|
||||
return
|
||||
}
|
||||
r.logger.Debug("worker completed batch", "worker", workerID, "batch", task.batchIndex, "embeddings", len(embeddings))
|
||||
}
|
||||
r.logger.Debug("embedding worker finished", "worker", workerID)
|
||||
}
|
||||
|
||||
// embedWithRetry attempts embedding with exponential backoff for API embedder
|
||||
func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Exponential backoff
|
||||
backoff := time.Duration(attempt*attempt) * time.Second
|
||||
if backoff > 10*time.Second {
|
||||
backoff = 10 * time.Second
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries)
|
||||
}
|
||||
|
||||
embeddings, err := r.embedder.EmbedSlice(paragraphs)
|
||||
if err == nil {
|
||||
// Validate embedding count
|
||||
if len(embeddings) != len(paragraphs) {
|
||||
return nil, fmt.Errorf("embedding count mismatch: expected %d, got %d", len(paragraphs), len(embeddings))
|
||||
}
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
// Only retry for API embedder errors (network/timeout)
|
||||
// For ONNX embedder, fail fast
|
||||
if _, isAPI := r.embedder.(*APIEmbedder); !isAPI {
|
||||
break
|
||||
}
|
||||
}
|
||||
r.logger.Debug("finished writing vectors", "batches", batchCount)
|
||||
r.resetIdleTimer()
|
||||
select {
|
||||
case LongJobStatusCh <- FinishedRAGStatus:
|
||||
default:
|
||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
|
||||
|
||||
return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// writeBatchToStorage writes a single batch of vectors to the database
|
||||
func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filename string) error {
|
||||
if len(result.embeddings) == 0 {
|
||||
// Empty batch, skip
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check context before starting
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Build all vectors for batch write
|
||||
vectors := make([]*models.VectorRow, 0, len(result.paragraphs))
|
||||
for j, text := range result.paragraphs {
|
||||
vectors = append(vectors, &models.VectorRow{
|
||||
Embeddings: result.embeddings[j],
|
||||
RawText: text,
|
||||
Slug: fmt.Sprintf("%s_%d_%d", filename, result.batchIndex+1, j),
|
||||
FileName: filename,
|
||||
})
|
||||
}
|
||||
|
||||
// Write all vectors in a single transaction
|
||||
if err := r.storage.WriteVectors(vectors); err != nil {
|
||||
r.logger.Error("failed to write vectors batch to DB", "error", err, "batch", result.batchIndex+1, "size", len(vectors))
|
||||
r.sendStatusNonBlocking(ErrRAGStatus)
|
||||
return fmt.Errorf("failed to write vectors batch: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -250,22 +544,26 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||
return r.embedder.Embed(line)
|
||||
}
|
||||
|
||||
func (r *RAG) SearchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
|
||||
func (r *RAG) searchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
|
||||
r.resetIdleTimer()
|
||||
return r.storage.SearchClosest(emb.Embedding, limit)
|
||||
}
|
||||
|
||||
func (r *RAG) SearchKeyword(query string, limit int) ([]models.VectorRow, error) {
|
||||
func (r *RAG) searchKeyword(query string, limit int) ([]models.VectorRow, error) {
|
||||
r.resetIdleTimer()
|
||||
sanitized := sanitizeFTSQuery(query)
|
||||
return r.storage.SearchKeyword(sanitized, limit)
|
||||
}
|
||||
|
||||
func (r *RAG) ListLoaded() ([]string, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.storage.ListFiles()
|
||||
}
|
||||
|
||||
func (r *RAG) RemoveFile(filename string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.resetIdleTimer()
|
||||
return r.storage.RemoveEmbByFileName(filename)
|
||||
}
|
||||
@@ -454,6 +752,9 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
||||
}
|
||||
|
||||
func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
r.resetIdleTimer()
|
||||
if len(results) == 0 {
|
||||
return "No relevant information found in the vector database.", nil
|
||||
}
|
||||
@@ -482,7 +783,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
|
||||
Embedding: emb,
|
||||
Index: 0,
|
||||
}
|
||||
topResults, err := r.SearchEmb(embResp, 1)
|
||||
topResults, err := r.searchEmb(embResp, 1)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to search for synthesis context", "error", err)
|
||||
return "", err
|
||||
@@ -509,6 +810,9 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
|
||||
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
r.resetIdleTimer()
|
||||
refined := r.RefineQuery(query)
|
||||
variations := r.GenerateQueryVariations(refined)
|
||||
|
||||
@@ -525,7 +829,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||
Embedding: emb,
|
||||
Index: 0,
|
||||
}
|
||||
results, err := r.SearchEmb(embResp, limit*2) // Get more candidates
|
||||
results, err := r.searchEmb(embResp, limit*2) // Get more candidates
|
||||
if err != nil {
|
||||
r.logger.Error("failed to search embeddings", "error", err, "query", q)
|
||||
continue
|
||||
@@ -543,7 +847,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||
})
|
||||
|
||||
// Perform keyword search
|
||||
kwResults, err := r.SearchKeyword(refined, limit*2)
|
||||
kwResults, err := r.searchKeyword(refined, limit*2)
|
||||
if err != nil {
|
||||
r.logger.Warn("keyword search failed, using only embeddings", "error", err)
|
||||
kwResults = nil
|
||||
@@ -621,6 +925,8 @@ func GetInstance() *RAG {
|
||||
}
|
||||
|
||||
func (r *RAG) resetIdleTimer() {
|
||||
r.idleMu.Lock()
|
||||
defer r.idleMu.Unlock()
|
||||
if r.idleTimer != nil {
|
||||
r.idleTimer.Stop()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user