Chore: fix linter complaints

This commit is contained in:
Grail Finder
2026-03-06 19:57:44 +03:00
parent 5f273681df
commit 014e297ae3
4 changed files with 14 additions and 47 deletions

2
.gitignore vendored
View File

@@ -3,6 +3,8 @@
testlog testlog
history/ history/
*.db *.db
*.db-shm
*.db-wal
config.toml config.toml
sysprompts/* sysprompts/*
!sysprompts/alice_bob_carl.json !sysprompts/alice_bob_carl.json

View File

@@ -213,7 +213,6 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log
if cudaLibPath == "" { if cudaLibPath == "" {
fmt.Println("WARNING: CUDA provider library not found, will use CPU") fmt.Println("WARNING: CUDA provider library not found, will use CPU")
} }
emb := &ONNXEmbedder{ emb := &ONNXEmbedder{
tokenizerPath: tokenizerPath, tokenizerPath: tokenizerPath,
dims: dims, dims: dims,
@@ -232,7 +231,6 @@ func (e *ONNXEmbedder) ensureInitialized() error {
if e.session != nil { if e.session != nil {
return nil return nil
} }
// Load tokenizer lazily // Load tokenizer lazily
if e.tokenizer == nil { if e.tokenizer == nil {
tok, err := pretrained.FromFile(e.tokenizerPath) tok, err := pretrained.FromFile(e.tokenizerPath)
@@ -241,7 +239,6 @@ func (e *ONNXEmbedder) ensureInitialized() error {
} }
e.tokenizer = tok e.tokenizer = tok
} }
onnxInitOnce.Do(func() { onnxInitOnce.Do(func() {
onnxruntime_go.SetSharedLibraryPath(onnxLibPath) onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
if err := onnxruntime_go.InitializeEnvironment(); err != nil { if err := onnxruntime_go.InitializeEnvironment(); err != nil {
@@ -260,13 +257,14 @@ func (e *ONNXEmbedder) ensureInitialized() error {
if !onnxReady { if !onnxReady {
return errors.New("ONNX runtime not ready") return errors.New("ONNX runtime not ready")
} }
// Create session options // Create session options
opts, err := onnxruntime_go.NewSessionOptions() opts, err := onnxruntime_go.NewSessionOptions()
if err != nil { if err != nil {
return fmt.Errorf("failed to create session options: %w", err) return fmt.Errorf("failed to create session options: %w", err)
} }
defer opts.Destroy() defer func() {
_ = opts.Destroy()
}()
// Try to add CUDA provider // Try to add CUDA provider
useCUDA := cudaLibPath != "" useCUDA := cudaLibPath != ""
@@ -276,7 +274,9 @@ func (e *ONNXEmbedder) ensureInitialized() error {
e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err) e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err)
useCUDA = false useCUDA = false
} else { } else {
defer cudaOpts.Destroy() defer func() {
_ = cudaOpts.Destroy()
}()
if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil { if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil {
e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err) e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err)
useCUDA = false useCUDA = false
@@ -286,7 +286,6 @@ func (e *ONNXEmbedder) ensureInitialized() error {
} }
} }
} }
if useCUDA { if useCUDA {
e.logger.Info("Using CUDA for ONNX inference") e.logger.Info("Using CUDA for ONNX inference")
} else { } else {

View File

@@ -19,10 +19,7 @@ import (
"github.com/neurosnap/sentences/english" "github.com/neurosnap/sentences/english"
) )
const ( const ()
// batchTimeout is the maximum time allowed for embedding a single batch
batchTimeout = 2 * time.Minute
)
var ( var (
// Status messages for TUI integration // Status messages for TUI integration
@@ -102,10 +99,6 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
return rag, nil return rag, nil
} }
func wordCounter(sentence string) int {
return len(strings.Split(strings.TrimSpace(sentence), " "))
}
func createChunks(sentences []string, wordLimit, overlapWords uint32) []string { func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
if len(sentences) == 0 { if len(sentences) == 0 {
return nil return nil
@@ -181,7 +174,6 @@ func (r *RAG) LoadRAG(fpath string) error {
func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
fileText, err := ExtractText(fpath) fileText, err := ExtractText(fpath)
if err != nil { if err != nil {
return err return err
@@ -190,7 +182,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
// Send initial status (non-blocking with retry) // Send initial status (non-blocking with retry)
r.sendStatusNonBlocking(LoadedFileRAGStatus) r.sendStatusNonBlocking(LoadedFileRAGStatus)
tokenizer, err := english.NewSentenceTokenizer(nil) tokenizer, err := english.NewSentenceTokenizer(nil)
if err != nil { if err != nil {
return err return err
@@ -210,7 +201,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
if len(paragraphs) == 0 { if len(paragraphs) == 0 {
return errors.New("no valid paragraphs found in file") return errors.New("no valid paragraphs found in file")
} }
totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize
r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize) r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize)
@@ -223,7 +213,7 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
concurrency = 1 concurrency = 1
} }
// If using ONNX embedder, limit concurrency to 1 due to mutex serialization // If using ONNX embedder, limit concurrency to 1 due to mutex serialization
isONNX := false var isONNX bool
if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX { if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX {
concurrency = 1 concurrency = 1
} }
@@ -258,7 +248,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
// Ensure task channel is closed when this goroutine exits // Ensure task channel is closed when this goroutine exits
defer close(taskCh) defer close(taskCh)
r.logger.Debug("task distributor started", "total_batches", totalBatches) r.logger.Debug("task distributor started", "total_batches", totalBatches)
for i := 0; i < totalBatches; i++ { for i := 0; i < totalBatches; i++ {
start := i * r.cfg.RAGBatchSize start := i * r.cfg.RAGBatchSize
end := start + r.cfg.RAGBatchSize end := start + r.cfg.RAGBatchSize
@@ -304,7 +293,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
resultsBuffer := make(map[int]batchResult) resultsBuffer := make(map[int]batchResult)
filename := path.Base(fpath) filename := path.Base(fpath)
batchesProcessed := 0 batchesProcessed := 0
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -382,7 +370,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
break break
} }
} }
r.logger.Debug("finished writing vectors", "batches", batchesProcessed) r.logger.Debug("finished writing vectors", "batches", batchesProcessed)
r.resetIdleTimer() r.resetIdleTimer()
r.sendStatusNonBlocking(FinishedRAGStatus) r.sendStatusNonBlocking(FinishedRAGStatus)
@@ -406,7 +393,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b
} }
} }
}() }()
for task := range taskCh { for task := range taskCh {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -432,7 +418,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b
r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex) r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex)
continue continue
} }
// Embed with retry for API embedder // Embed with retry for API embedder
embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3) embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3)
if err != nil { if err != nil {
@@ -444,7 +429,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b
} }
return return
} }
// Send result with context awareness // Send result with context awareness
select { select {
case resultCh <- batchResult{ case resultCh <- batchResult{
@@ -465,7 +449,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b
// embedWithRetry attempts embedding with exponential backoff for API embedder // embedWithRetry attempts embedding with exponential backoff for API embedder
func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) { func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) {
var lastErr error var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ { for attempt := 0; attempt < maxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
// Exponential backoff // Exponential backoff
@@ -473,13 +456,11 @@ func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetrie
if backoff > 10*time.Second { if backoff > 10*time.Second {
backoff = 10 * time.Second backoff = 10 * time.Second
} }
select { select {
case <-time.After(backoff): case <-time.After(backoff):
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
} }
r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries) r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries)
} }
@@ -499,7 +480,6 @@ func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetrie
break break
} }
} }
return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr) return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr)
} }
@@ -509,7 +489,6 @@ func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filen
// Empty batch, skip // Empty batch, skip
return nil return nil
} }
// Check context before starting // Check context before starting
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -534,7 +513,6 @@ func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filen
r.sendStatusNonBlocking(ErrRAGStatus) r.sendStatusNonBlocking(ErrRAGStatus)
return fmt.Errorf("failed to write vectors batch: %w", err) return fmt.Errorf("failed to write vectors batch: %w", err)
} }
r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs)) r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs))
return nil return nil
} }

View File

@@ -64,7 +64,6 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
return err return err
} }
embeddingSize := len(row.Embeddings) embeddingSize := len(row.Embeddings)
// Start transaction // Start transaction
tx, err := vs.sqlxDB.Beginx() tx, err := vs.sqlxDB.Beginx()
if err != nil { if err != nil {
@@ -72,7 +71,7 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
} }
defer func() { defer func() {
if err != nil { if err != nil {
tx.Rollback() _ = tx.Rollback()
} }
}() }()
@@ -86,14 +85,12 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug) vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
return err return err
} }
// Insert into FTS table // Insert into FTS table
ftsQuery := `INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES (?, ?, ?, ?)` ftsQuery := `INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES (?, ?, ?, ?)`
if _, err := tx.Exec(ftsQuery, row.Slug, row.RawText, row.FileName, embeddingSize); err != nil { if _, err := tx.Exec(ftsQuery, row.Slug, row.RawText, row.FileName, embeddingSize); err != nil {
vs.logger.Error("failed to write to FTS table", "error", err, "slug", row.Slug) vs.logger.Error("failed to write to FTS table", "error", err, "slug", row.Slug)
return err return err
} }
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
vs.logger.Error("failed to commit transaction", "error", err) vs.logger.Error("failed to commit transaction", "error", err)
@@ -133,7 +130,6 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error {
if err != nil { if err != nil {
return err return err
} }
// Start transaction // Start transaction
tx, err := vs.sqlxDB.Beginx() tx, err := vs.sqlxDB.Beginx()
if err != nil { if err != nil {
@@ -141,7 +137,7 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error {
} }
defer func() { defer func() {
if err != nil { if err != nil {
tx.Rollback() _ = tx.Rollback()
} }
}() }()
@@ -161,7 +157,6 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error {
vs.logger.Error("failed to write vectors batch", "error", err, "batch_size", len(rows)) vs.logger.Error("failed to write vectors batch", "error", err, "batch_size", len(rows))
return err return err
} }
// Build batch insert for FTS table // Build batch insert for FTS table
ftsPlaceholders := make([]string, 0, len(rows)) ftsPlaceholders := make([]string, 0, len(rows))
ftsArgs := make([]any, 0, len(rows)*4) ftsArgs := make([]any, 0, len(rows)*4)
@@ -170,15 +165,12 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error {
ftsPlaceholders = append(ftsPlaceholders, "(?, ?, ?, ?)") ftsPlaceholders = append(ftsPlaceholders, "(?, ?, ?, ?)")
ftsArgs = append(ftsArgs, row.Slug, row.RawText, row.FileName, embeddingSize) ftsArgs = append(ftsArgs, row.Slug, row.RawText, row.FileName, embeddingSize)
} }
ftsQuery := fmt.Sprintf( ftsQuery := "INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES " +
"INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES %s", strings.Join(ftsPlaceholders, ", ")
strings.Join(ftsPlaceholders, ", "),
)
if _, err := tx.Exec(ftsQuery, ftsArgs...); err != nil { if _, err := tx.Exec(ftsQuery, ftsArgs...); err != nil {
vs.logger.Error("failed to write FTS batch", "error", err, "batch_size", len(rows)) vs.logger.Error("failed to write FTS batch", "error", err, "batch_size", len(rows))
return err return err
} }
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
vs.logger.Error("failed to commit transaction", "error", err) vs.logger.Error("failed to commit transaction", "error", err)
@@ -218,14 +210,12 @@ func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.Vec
if err != nil { if err != nil {
return nil, err return nil, err
} }
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
rows, err := vs.sqlxDB.Query(querySQL) rows, err := vs.sqlxDB.Query(querySQL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
type SearchResult struct { type SearchResult struct {
vector models.VectorRow vector models.VectorRow
distance float32 distance float32
@@ -241,7 +231,6 @@ func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.Vec
vs.logger.Error("failed to scan row", "error", err) vs.logger.Error("failed to scan row", "error", err)
continue continue
} }
storedEmbeddings := DeserializeVector(embeddingsBlob) storedEmbeddings := DeserializeVector(embeddingsBlob)
similarity := cosineSimilarity(query, storedEmbeddings) similarity := cosineSimilarity(query, storedEmbeddings)
distance := 1 - similarity distance := 1 - similarity
@@ -264,7 +253,6 @@ func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.Vec
topResults = topResults[:limit] topResults = topResults[:limit]
} }
} }
results := make([]models.VectorRow, 0, len(topResults)) results := make([]models.VectorRow, 0, len(topResults))
for _, result := range topResults { for _, result := range topResults {
result.vector.Distance = result.distance result.vector.Distance = result.distance