diff --git a/.gitignore b/.gitignore index 15b83b4..b3baaec 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ testlog history/ *.db +*.db-shm +*.db-wal config.toml sysprompts/* !sysprompts/alice_bob_carl.json diff --git a/rag/embedder.go b/rag/embedder.go index fd4cfa7..5a4aae0 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -213,7 +213,6 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log if cudaLibPath == "" { fmt.Println("WARNING: CUDA provider library not found, will use CPU") } - emb := &ONNXEmbedder{ tokenizerPath: tokenizerPath, dims: dims, @@ -232,7 +231,6 @@ func (e *ONNXEmbedder) ensureInitialized() error { if e.session != nil { return nil } - // Load tokenizer lazily if e.tokenizer == nil { tok, err := pretrained.FromFile(e.tokenizerPath) @@ -241,7 +239,6 @@ func (e *ONNXEmbedder) ensureInitialized() error { } e.tokenizer = tok } - onnxInitOnce.Do(func() { onnxruntime_go.SetSharedLibraryPath(onnxLibPath) if err := onnxruntime_go.InitializeEnvironment(); err != nil { @@ -260,13 +257,14 @@ func (e *ONNXEmbedder) ensureInitialized() error { if !onnxReady { return errors.New("ONNX runtime not ready") } - // Create session options opts, err := onnxruntime_go.NewSessionOptions() if err != nil { return fmt.Errorf("failed to create session options: %w", err) } - defer opts.Destroy() + defer func() { + _ = opts.Destroy() + }() // Try to add CUDA provider useCUDA := cudaLibPath != "" @@ -276,7 +274,9 @@ func (e *ONNXEmbedder) ensureInitialized() error { e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err) useCUDA = false } else { - defer cudaOpts.Destroy() + defer func() { + _ = cudaOpts.Destroy() + }() if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil { e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err) useCUDA = false @@ -286,7 +286,6 @@ func (e *ONNXEmbedder) ensureInitialized() error { } } } - if useCUDA { e.logger.Info("Using CUDA for ONNX inference") } else { diff --git a/rag/rag.go b/rag/rag.go index 180ad50..3db4303 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -19,10 +19,7 @@ import ( "github.com/neurosnap/sentences/english" ) -const ( - // batchTimeout is the maximum time allowed for embedding a single batch - batchTimeout = 2 * time.Minute -) +const () var ( // Status messages for TUI integration @@ -102,10 +99,6 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) { return rag, nil } -func wordCounter(sentence string) int { - return len(strings.Split(strings.TrimSpace(sentence), " ")) -} - func createChunks(sentences []string, wordLimit, overlapWords uint32) []string { if len(sentences) == 0 { return nil @@ -181,7 +174,6 @@ func (r *RAG) LoadRAG(fpath string) error { func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { r.mu.Lock() defer r.mu.Unlock() - fileText, err := ExtractText(fpath) if err != nil { return err @@ -190,7 +182,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { // Send initial status (non-blocking with retry) r.sendStatusNonBlocking(LoadedFileRAGStatus) - tokenizer, err := english.NewSentenceTokenizer(nil) if err != nil { return err @@ -210,7 +201,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { if len(paragraphs) == 0 { return errors.New("no valid paragraphs found in file") } - totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize) @@ -223,7 +213,7 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { concurrency = 1 } // If using ONNX embedder, limit concurrency to 1 due to mutex serialization - isONNX := false + var isONNX bool if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX { concurrency = 1 } @@ -258,7 +248,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { // Ensure task channel is closed when this goroutine exits defer close(taskCh) r.logger.Debug("task distributor started", "total_batches", totalBatches) - for i := 0; i < totalBatches; i++ { start := i * r.cfg.RAGBatchSize end := start + r.cfg.RAGBatchSize @@ -304,7 +293,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { resultsBuffer := make(map[int]batchResult) filename := path.Base(fpath) batchesProcessed := 0 - for { select { case <-ctx.Done(): @@ -382,7 +370,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { break } } - r.logger.Debug("finished writing vectors", "batches", batchesProcessed) r.resetIdleTimer() r.sendStatusNonBlocking(FinishedRAGStatus) @@ -406,7 +393,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b } } }() - for task := range taskCh { select { case <-ctx.Done(): @@ -432,7 +418,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex) continue } - // Embed with retry for API embedder embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3) if err != nil { @@ -444,7 +429,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b } return } - // Send result with context awareness select { case resultCh <- batchResult{ @@ -465,7 +449,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b // embedWithRetry attempts embedding with exponential backoff for API embedder func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) { var lastErr error - for attempt := 0; attempt < maxRetries; attempt++ { if attempt > 0 { // Exponential backoff @@ -473,13 +456,11 @@ func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetrie if backoff > 10*time.Second { backoff = 10 * time.Second } - select { case <-time.After(backoff): case <-ctx.Done(): return nil, ctx.Err() } - r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries) } @@ -499,7 +480,6 @@ func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetrie break } } - return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr) } @@ -509,7 +489,6 @@ func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filen // Empty batch, skip return nil } - // Check context before starting select { case <-ctx.Done(): @@ -534,7 +513,6 @@ func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filen r.sendStatusNonBlocking(ErrRAGStatus) return fmt.Errorf("failed to write vectors batch: %w", err) } - r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs)) return nil } diff --git a/rag/storage.go b/rag/storage.go index 1e6b013..62477b6 100644 --- a/rag/storage.go +++ b/rag/storage.go @@ -64,7 +64,6 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { return err } embeddingSize := len(row.Embeddings) - // Start transaction tx, err := vs.sqlxDB.Beginx() if err != nil { @@ -72,7 +71,7 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { } defer func() { if err != nil { - tx.Rollback() + _ = tx.Rollback() } }() @@ -86,14 +85,12 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug) return err } - // Insert into FTS table ftsQuery := `INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES (?, ?, ?, ?)` if _, err := tx.Exec(ftsQuery, row.Slug, row.RawText, row.FileName, embeddingSize); err != nil { vs.logger.Error("failed to write to FTS table", "error", err, "slug", row.Slug) return err } - err = tx.Commit() if err != nil { vs.logger.Error("failed to commit transaction", "error", err) @@ -133,7 +130,6 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error { if err != nil { return err } - // Start transaction tx, err := vs.sqlxDB.Beginx() if err != nil { @@ -141,7 +137,7 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error { } defer func() { if err != nil { - tx.Rollback() + _ = tx.Rollback() } }() @@ -161,7 +157,6 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error { vs.logger.Error("failed to write vectors batch", "error", err, "batch_size", len(rows)) return err } - // Build batch insert for FTS table ftsPlaceholders := make([]string, 0, len(rows)) ftsArgs := make([]any, 0, len(rows)*4) @@ -170,15 +165,12 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error { ftsPlaceholders = append(ftsPlaceholders, "(?, ?, ?, ?)") ftsArgs = append(ftsArgs, row.Slug, row.RawText, row.FileName, embeddingSize) } - ftsQuery := fmt.Sprintf( - "INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES %s", - strings.Join(ftsPlaceholders, ", "), - ) + ftsQuery := "INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES " + + strings.Join(ftsPlaceholders, ", ") if _, err := tx.Exec(ftsQuery, ftsArgs...); err != nil { vs.logger.Error("failed to write FTS batch", "error", err, "batch_size", len(rows)) return err } - err = tx.Commit() if err != nil { vs.logger.Error("failed to commit transaction", "error", err) @@ -218,14 +210,12 @@ func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.Vec if err != nil { return nil, err } - querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName rows, err := vs.sqlxDB.Query(querySQL) if err != nil { return nil, err } defer rows.Close() - type SearchResult struct { vector models.VectorRow distance float32 @@ -241,7 +231,6 @@ func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.Vec vs.logger.Error("failed to scan row", "error", err) continue } - storedEmbeddings := DeserializeVector(embeddingsBlob) similarity := cosineSimilarity(query, storedEmbeddings) distance := 1 - similarity @@ -264,7 +253,6 @@ func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.Vec topResults = topResults[:limit] } } - results := make([]models.VectorRow, 0, len(topResults)) for _, result := range topResults { result.vector.Distance = result.distance