diff --git a/rag/embedder.go b/rag/embedder.go index 39f4b5c..fd4cfa7 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "sync" + "time" "github.com/sugarme/tokenizer" "github.com/sugarme/tokenizer/pretrained" @@ -33,8 +34,10 @@ type APIEmbedder struct { func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder { return &APIEmbedder{ logger: l, - client: &http.Client{}, - cfg: cfg, + client: &http.Client{ + Timeout: 30 * time.Second, + }, + cfg: cfg, } } diff --git a/rag/rag.go b/rag/rag.go index 9271b60..180ad50 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -1,6 +1,7 @@ package rag import ( + "context" "errors" "fmt" "gf-lt/config" @@ -9,6 +10,7 @@ import ( "log/slog" "path" "regexp" + "runtime" "sort" "strings" "sync" @@ -17,9 +19,14 @@ import ( "github.com/neurosnap/sentences/english" ) +const ( + // batchTimeout is the maximum time allowed for embedding a single batch + batchTimeout = 2 * time.Minute +) + var ( // Status messages for TUI integration - LongJobStatusCh = make(chan string, 10) // Increased buffer size to prevent blocking + LongJobStatusCh = make(chan string, 100) // Increased buffer size for parallel batch updates FinishedRAGStatus = "finished loading RAG file; press Enter" LoadedFileRAGStatus = "loaded file" ErrRAGStatus = "some error occurred; failed to transfer data to vector db" @@ -31,12 +38,38 @@ type RAG struct { cfg *config.Config embedder Embedder storage *VectorStorage - mu sync.Mutex + mu sync.RWMutex + idleMu sync.Mutex fallbackMsg string idleTimer *time.Timer idleTimeout time.Duration } +// batchTask represents a single batch to be embedded +type batchTask struct { + batchIndex int + paragraphs []string + filename string + totalBatches int +} + +// batchResult represents the result of embedding a batch +type batchResult struct { + batchIndex int + embeddings [][]float32 + paragraphs []string + filename string +} + +// sendStatusNonBlocking sends a status message without blocking 
+func (r *RAG) sendStatusNonBlocking(status string) { + select { + case LongJobStatusCh <- status: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", status) + } +} + func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) { var embedder Embedder var fallbackMsg string @@ -142,18 +175,22 @@ func sanitizeFTSQuery(query string) string { } func (r *RAG) LoadRAG(fpath string) error { + return r.LoadRAGWithContext(context.Background(), fpath) +} + +func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { r.mu.Lock() defer r.mu.Unlock() + fileText, err := ExtractText(fpath) if err != nil { return err } r.logger.Debug("rag: loaded file", "fp", fpath) - select { - case LongJobStatusCh <- LoadedFileRAGStatus: - default: - r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus) - } + + // Send initial status (non-blocking with retry) + r.sendStatusNonBlocking(LoadedFileRAGStatus) + tokenizer, err := english.NewSentenceTokenizer(nil) if err != nil { return err @@ -163,6 +200,7 @@ func (r *RAG) LoadRAG(fpath string) error { for i, s := range sentences { sents[i] = s.Text } + // Create chunks with overlap paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords) // Adjust batch size if needed @@ -172,76 +210,332 @@ func (r *RAG) LoadRAG(fpath string) error { if len(paragraphs) == 0 { return errors.New("no valid paragraphs found in file") } - // Process paragraphs in batches synchronously - batchCount := 0 - for i := 0; i < len(paragraphs); i += r.cfg.RAGBatchSize { - end := i + r.cfg.RAGBatchSize - if end > len(paragraphs) { - end = len(paragraphs) - } - batch := paragraphs[i:end] - batchCount++ - // Filter empty paragraphs - nonEmptyBatch := make([]string, 0, len(batch)) - for _, p := range batch { - if strings.TrimSpace(p) != "" { - nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p)) + + 
totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize + r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize) + + // Determine concurrency level + concurrency := runtime.NumCPU() + if concurrency > totalBatches { + concurrency = totalBatches + } + if concurrency < 1 { + concurrency = 1 + } + // If using ONNX embedder, limit concurrency to 1 due to mutex serialization + isONNX := false + if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX { + concurrency = 1 + } + embedderType := "API" + if isONNX { + embedderType = "ONNX" + } + r.logger.Debug("parallel embedding setup", + "total_batches", totalBatches, + "concurrency", concurrency, + "embedder", embedderType, + "batch_size", r.cfg.RAGBatchSize) + + // Create context with timeout (30 minutes) and cancellation for error handling + ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) + defer cancel() + + // Channels for task distribution and results + taskCh := make(chan batchTask, totalBatches) + resultCh := make(chan batchResult, totalBatches) + errorCh := make(chan error, totalBatches) + + // Start worker goroutines + var wg sync.WaitGroup + for w := 0; w < concurrency; w++ { + wg.Add(1) + go r.embeddingWorker(ctx, w, taskCh, resultCh, errorCh, &wg) + } + + // Close task channel after all tasks are sent (by separate goroutine) + go func() { + // Ensure task channel is closed when this goroutine exits + defer close(taskCh) + r.logger.Debug("task distributor started", "total_batches", totalBatches) + + for i := 0; i < totalBatches; i++ { + start := i * r.cfg.RAGBatchSize + end := start + r.cfg.RAGBatchSize + if end > len(paragraphs) { + end = len(paragraphs) + } + batch := paragraphs[start:end] + + // Filter empty paragraphs + nonEmptyBatch := make([]string, 0, len(batch)) + for _, p := range batch { + if strings.TrimSpace(p) != "" { + nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p)) + } + } + + task := batchTask{ + 
batchIndex: i, + paragraphs: nonEmptyBatch, + filename: path.Base(fpath), + totalBatches: totalBatches, + } + + select { + case taskCh <- task: + r.logger.Debug("task distributor sent batch", "batch", i, "paragraphs", len(nonEmptyBatch)) + case <-ctx.Done(): + r.logger.Debug("task distributor cancelled", "batches_sent", i+1, "total_batches", totalBatches) + return } } - if len(nonEmptyBatch) == 0 { + r.logger.Debug("task distributor finished", "batches_sent", totalBatches) + }() + + // Wait for workers to finish and close result channel + go func() { + wg.Wait() + close(resultCh) + }() + + // Process results in order and write to database + nextExpectedBatch := 0 + resultsBuffer := make(map[int]batchResult) + filename := path.Base(fpath) + batchesProcessed := 0 + + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case err := <-errorCh: + // First error from any worker, cancel everything + cancel() + r.logger.Error("embedding worker failed", "error", err) + r.sendStatusNonBlocking(ErrRAGStatus) + return fmt.Errorf("embedding failed: %w", err) + + case result, ok := <-resultCh: + if !ok { + // All results processed + resultCh = nil + r.logger.Debug("result channel closed", "batches_processed", batchesProcessed, "total_batches", totalBatches) + continue + } + + // Store result in buffer + resultsBuffer[result.batchIndex] = result + + // Process buffered results in order + for { + if res, exists := resultsBuffer[nextExpectedBatch]; exists { + // Write this batch to database + if err := r.writeBatchToStorage(ctx, res, filename); err != nil { + cancel() + return err + } + + batchesProcessed++ + // Send progress update + statusMsg := fmt.Sprintf("processed batch %d/%d", batchesProcessed, totalBatches) + r.sendStatusNonBlocking(statusMsg) + + delete(resultsBuffer, nextExpectedBatch) + nextExpectedBatch++ + } else { + break + } + } + + default: + // No channels ready, check for deadlock conditions + if resultCh == nil && nextExpectedBatch < totalBatches { + // 
Missing batch results after result channel closed + r.logger.Error("missing batch results", + "expected", totalBatches, + "received", nextExpectedBatch, + "missing", totalBatches-nextExpectedBatch) + + // Wait a short time for any delayed errors, then cancel + select { + case <-time.After(5 * time.Second): + cancel() + return fmt.Errorf("missing batch results: expected %d, got %d", totalBatches, nextExpectedBatch) + case <-ctx.Done(): + return ctx.Err() + case err := <-errorCh: + cancel() + r.logger.Error("embedding worker failed after result channel closed", "error", err) + r.sendStatusNonBlocking(ErrRAGStatus) + return fmt.Errorf("embedding failed: %w", err) + } + } + // If we reach here, no deadlock yet, just busy loop prevention + time.Sleep(100 * time.Millisecond) + } + + // Check if we're done + if resultCh == nil && nextExpectedBatch >= totalBatches { + r.logger.Debug("all batches processed successfully", "total", totalBatches) + break + } + } + + r.logger.Debug("finished writing vectors", "batches", batchesProcessed) + r.resetIdleTimer() + r.sendStatusNonBlocking(FinishedRAGStatus) + return nil +} + +// embeddingWorker processes batch embedding tasks +func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan batchTask, resultCh chan<- batchResult, errorCh chan<- error, wg *sync.WaitGroup) { + defer wg.Done() + r.logger.Debug("embedding worker started", "worker", workerID) + + // Panic recovery to ensure worker doesn't crash silently + defer func() { + if rec := recover(); rec != nil { + r.logger.Error("embedding worker panicked", "worker", workerID, "panic", rec) + // Try to send error, but don't block if channel is full + select { + case errorCh <- fmt.Errorf("worker %d panicked: %v", workerID, rec): + default: + r.logger.Warn("error channel full, dropping panic error", "worker", workerID) + } + } + }() + + for task := range taskCh { + select { + case <-ctx.Done(): + r.logger.Debug("embedding worker cancelled", "worker", workerID) + 
return + default: + } + r.logger.Debug("worker processing batch", "worker", workerID, "batch", task.batchIndex, "paragraphs", len(task.paragraphs), "total_batches", task.totalBatches) + + // Skip empty batches + if len(task.paragraphs) == 0 { + select { + case resultCh <- batchResult{ + batchIndex: task.batchIndex, + embeddings: nil, + paragraphs: nil, + filename: task.filename, + }: + case <-ctx.Done(): + r.logger.Debug("embedding worker cancelled while sending empty batch", "worker", workerID) + return + } + r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex) continue } - // Embed the batch - embeddings, err := r.embedder.EmbedSlice(nonEmptyBatch) + + // Embed with retry for API embedder + embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3) if err != nil { - r.logger.Error("failed to embed batch", "error", err, "batch", batchCount) + // Try to send error, but don't block indefinitely select { - case LongJobStatusCh <- ErrRAGStatus: - default: - r.logger.Warn("LongJobStatusCh channel full, dropping message") + case errorCh <- fmt.Errorf("worker %d batch %d: %w", workerID, task.batchIndex, err): + case <-ctx.Done(): + r.logger.Debug("embedding worker cancelled while sending error", "worker", workerID) } - return fmt.Errorf("failed to embed batch %d: %w", batchCount, err) + return } - if len(embeddings) != len(nonEmptyBatch) { - err := errors.New("embedding count mismatch") - r.logger.Error("embedding mismatch", "expected", len(nonEmptyBatch), "got", len(embeddings)) - return err - } - // Write vectors to storage - filename := path.Base(fpath) - for j, text := range nonEmptyBatch { - vector := models.VectorRow{ - Embeddings: embeddings[j], - RawText: text, - Slug: fmt.Sprintf("%s_%d_%d", filename, batchCount, j), - FileName: filename, - } - if err := r.storage.WriteVector(&vector); err != nil { - r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug) - select { - case LongJobStatusCh <- 
ErrRAGStatus: - default: - r.logger.Warn("LongJobStatusCh channel full, dropping message") - } - return fmt.Errorf("failed to write vector: %w", err) - } - } - r.logger.Debug("wrote batch to db", "batch", batchCount, "size", len(nonEmptyBatch)) - // Send progress status - statusMsg := fmt.Sprintf("processed batch %d/%d", batchCount, (len(paragraphs)+r.cfg.RAGBatchSize-1)/r.cfg.RAGBatchSize) + + // Send result with context awareness select { - case LongJobStatusCh <- statusMsg: - default: - r.logger.Warn("LongJobStatusCh channel full, dropping message") + case resultCh <- batchResult{ + batchIndex: task.batchIndex, + embeddings: embeddings, + paragraphs: task.paragraphs, + filename: task.filename, + }: + case <-ctx.Done(): + r.logger.Debug("embedding worker cancelled while sending result", "worker", workerID) + return + } + r.logger.Debug("worker completed batch", "worker", workerID, "batch", task.batchIndex, "embeddings", len(embeddings)) + } + r.logger.Debug("embedding worker finished", "worker", workerID) +} + +// embedWithRetry attempts embedding with exponential backoff for API embedder +func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) { + var lastErr error + + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + // Exponential backoff + backoff := time.Duration(attempt*attempt) * time.Second + if backoff > 10*time.Second { + backoff = 10 * time.Second + } + + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } + + r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries) + } + + embeddings, err := r.embedder.EmbedSlice(paragraphs) + if err == nil { + // Validate embedding count + if len(embeddings) != len(paragraphs) { + return nil, fmt.Errorf("embedding count mismatch: expected %d, got %d", len(paragraphs), len(embeddings)) + } + return embeddings, nil + } + + lastErr = err + // Only retry for API embedder errors 
(network/timeout) + // For ONNX embedder, fail fast + if _, isAPI := r.embedder.(*APIEmbedder); !isAPI { + break } } - r.logger.Debug("finished writing vectors", "batches", batchCount) - r.resetIdleTimer() - select { - case LongJobStatusCh <- FinishedRAGStatus: - default: - r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus) + + return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr) +} + +// writeBatchToStorage writes a single batch of vectors to the database +func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filename string) error { + if len(result.embeddings) == 0 { + // Empty batch, skip + return nil } + + // Check context before starting + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Build all vectors for batch write + vectors := make([]*models.VectorRow, 0, len(result.paragraphs)) + for j, text := range result.paragraphs { + vectors = append(vectors, &models.VectorRow{ + Embeddings: result.embeddings[j], + RawText: text, + Slug: fmt.Sprintf("%s_%d_%d", filename, result.batchIndex+1, j), + FileName: filename, + }) + } + + // Write all vectors in a single transaction + if err := r.storage.WriteVectors(vectors); err != nil { + r.logger.Error("failed to write vectors batch to DB", "error", err, "batch", result.batchIndex+1, "size", len(vectors)) + r.sendStatusNonBlocking(ErrRAGStatus) + return fmt.Errorf("failed to write vectors batch: %w", err) + } + + r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs)) return nil } @@ -250,22 +544,26 @@ func (r *RAG) LineToVector(line string) ([]float32, error) { return r.embedder.Embed(line) } -func (r *RAG) SearchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) { +func (r *RAG) searchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) { r.resetIdleTimer() return r.storage.SearchClosest(emb.Embedding, 
limit) } -func (r *RAG) SearchKeyword(query string, limit int) ([]models.VectorRow, error) { +func (r *RAG) searchKeyword(query string, limit int) ([]models.VectorRow, error) { r.resetIdleTimer() sanitized := sanitizeFTSQuery(query) return r.storage.SearchKeyword(sanitized, limit) } func (r *RAG) ListLoaded() ([]string, error) { + r.mu.RLock() + defer r.mu.RUnlock() return r.storage.ListFiles() } func (r *RAG) RemoveFile(filename string) error { + r.mu.Lock() + defer r.mu.Unlock() r.resetIdleTimer() return r.storage.RemoveEmbByFileName(filename) } @@ -454,6 +752,9 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V } func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) { + r.mu.RLock() + defer r.mu.RUnlock() + r.resetIdleTimer() if len(results) == 0 { return "No relevant information found in the vector database.", nil } @@ -482,7 +783,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string Embedding: emb, Index: 0, } - topResults, err := r.SearchEmb(embResp, 1) + topResults, err := r.searchEmb(embResp, 1) if err != nil { r.logger.Error("failed to search for synthesis context", "error", err) return "", err @@ -509,6 +810,9 @@ func truncateString(s string, maxLen int) string { } func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { + r.mu.RLock() + defer r.mu.RUnlock() + r.resetIdleTimer() refined := r.RefineQuery(query) variations := r.GenerateQueryVariations(refined) @@ -525,7 +829,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { Embedding: emb, Index: 0, } - results, err := r.SearchEmb(embResp, limit*2) // Get more candidates + results, err := r.searchEmb(embResp, limit*2) // Get more candidates if err != nil { r.logger.Error("failed to search embeddings", "error", err, "query", q) continue @@ -543,7 +847,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { }) // Perform keyword search - 
kwResults, err := r.SearchKeyword(refined, limit*2) + kwResults, err := r.searchKeyword(refined, limit*2) if err != nil { r.logger.Warn("keyword search failed, using only embeddings", "error", err) kwResults = nil @@ -621,6 +925,8 @@ func GetInstance() *RAG { } func (r *RAG) resetIdleTimer() { + r.idleMu.Lock() + defer r.idleMu.Unlock() if r.idleTimer != nil { r.idleTimer.Stop() } diff --git a/rag/storage.go b/rag/storage.go index 110cea2..1e6b013 100644 --- a/rag/storage.go +++ b/rag/storage.go @@ -102,6 +102,92 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { return nil } +// WriteVectors stores multiple embedding vectors in a single transaction +func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error { + if len(rows) == 0 { + return nil + } + // SQLite has limit of 999 parameters per statement, each row uses 4 parameters + const maxBatchSize = 200 // 200 * 4 = 800 < 999 + if len(rows) > maxBatchSize { + // Process in chunks + for i := 0; i < len(rows); i += maxBatchSize { + end := i + maxBatchSize + if end > len(rows) { + end = len(rows) + } + if err := vs.WriteVectors(rows[i:end]); err != nil { + return err + } + } + return nil + } + // All rows should have same embedding size (same model) + firstSize := len(rows[0].Embeddings) + for i, row := range rows { + if len(row.Embeddings) != firstSize { + return fmt.Errorf("embedding size mismatch: row %d has size %d, expected %d", i, len(row.Embeddings), firstSize) + } + } + tableName, err := vs.getTableName(rows[0].Embeddings) + if err != nil { + return err + } + + // Start transaction + tx, err := vs.sqlxDB.Beginx() + if err != nil { + return err + } + defer func() { + if err != nil { + tx.Rollback() + } + }() + + // Build batch insert for embeddings table + embeddingPlaceholders := make([]string, 0, len(rows)) + embeddingArgs := make([]any, 0, len(rows)*4) + for _, row := range rows { + embeddingPlaceholders = append(embeddingPlaceholders, "(?, ?, ?, ?)") + embeddingArgs = 
append(embeddingArgs, SerializeVector(row.Embeddings), row.Slug, row.RawText, row.FileName) + } + embeddingQuery := fmt.Sprintf( + "INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES %s", + tableName, + strings.Join(embeddingPlaceholders, ", "), + ) + if _, err = tx.Exec(embeddingQuery, embeddingArgs...); err != nil { + vs.logger.Error("failed to write vectors batch", "error", err, "batch_size", len(rows)) + return err + } + + // Build batch insert for FTS table + ftsPlaceholders := make([]string, 0, len(rows)) + ftsArgs := make([]any, 0, len(rows)*4) + embeddingSize := len(rows[0].Embeddings) + for _, row := range rows { + ftsPlaceholders = append(ftsPlaceholders, "(?, ?, ?, ?)") + ftsArgs = append(ftsArgs, row.Slug, row.RawText, row.FileName, embeddingSize) + } + ftsQuery := fmt.Sprintf( + "INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES %s", + strings.Join(ftsPlaceholders, ", "), + ) + if _, err = tx.Exec(ftsQuery, ftsArgs...); err != nil { + vs.logger.Error("failed to write FTS batch", "error", err, "batch_size", len(rows)) + return err + } + + err = tx.Commit() + if err != nil { + vs.logger.Error("failed to commit transaction", "error", err) + return err + } + vs.logger.Debug("wrote vectors batch", "batch_size", len(rows)) + return nil +} + // getTableName determines which table to use based on embedding size func (vs *VectorStorage) getTableName(emb []float32) (string, error) { size := len(emb) diff --git a/storage/storage.go b/storage/storage.go index 9ad9745..57631da 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -102,6 +102,22 @@ func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { logger.Error("failed to open db connection", "error", err) return nil } + // Enable WAL mode for better concurrency and performance + if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil { + logger.Warn("failed to enable WAL mode", "error", err) + } + if _, err := db.Exec("PRAGMA synchronous = 
NORMAL;"); err != nil { + logger.Warn("failed to set synchronous mode", "error", err) + } + // Increase page cache to 64 MiB (negative = KiB; SQLite's default of -2000 is only 2 MiB) + if _, err := db.Exec("PRAGMA cache_size = -64000;"); err != nil { + logger.Warn("failed to set cache size", "error", err) + } + // Log actual journal mode for debugging + var journalMode string + if err := db.QueryRow("PRAGMA journal_mode;").Scan(&journalMode); err == nil { + logger.Debug("SQLite journal mode", "mode", journalMode) + } p := ProviderSQL{db: db, logger: logger} if err := p.Migrate(); err != nil { logger.Error("migration failed, app cannot start", "error", err)