Chore: linter complaints
This commit is contained in:
@@ -213,7 +213,6 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log
|
||||
if cudaLibPath == "" {
|
||||
fmt.Println("WARNING: CUDA provider library not found, will use CPU")
|
||||
}
|
||||
|
||||
emb := &ONNXEmbedder{
|
||||
tokenizerPath: tokenizerPath,
|
||||
dims: dims,
|
||||
@@ -232,7 +231,6 @@ func (e *ONNXEmbedder) ensureInitialized() error {
|
||||
if e.session != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load tokenizer lazily
|
||||
if e.tokenizer == nil {
|
||||
tok, err := pretrained.FromFile(e.tokenizerPath)
|
||||
@@ -241,7 +239,6 @@ func (e *ONNXEmbedder) ensureInitialized() error {
|
||||
}
|
||||
e.tokenizer = tok
|
||||
}
|
||||
|
||||
onnxInitOnce.Do(func() {
|
||||
onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
|
||||
if err := onnxruntime_go.InitializeEnvironment(); err != nil {
|
||||
@@ -260,13 +257,14 @@ func (e *ONNXEmbedder) ensureInitialized() error {
|
||||
if !onnxReady {
|
||||
return errors.New("ONNX runtime not ready")
|
||||
}
|
||||
|
||||
// Create session options
|
||||
opts, err := onnxruntime_go.NewSessionOptions()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create session options: %w", err)
|
||||
}
|
||||
defer opts.Destroy()
|
||||
defer func() {
|
||||
_ = opts.Destroy()
|
||||
}()
|
||||
|
||||
// Try to add CUDA provider
|
||||
useCUDA := cudaLibPath != ""
|
||||
@@ -276,7 +274,9 @@ func (e *ONNXEmbedder) ensureInitialized() error {
|
||||
e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err)
|
||||
useCUDA = false
|
||||
} else {
|
||||
defer cudaOpts.Destroy()
|
||||
defer func() {
|
||||
_ = cudaOpts.Destroy()
|
||||
}()
|
||||
if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil {
|
||||
e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err)
|
||||
useCUDA = false
|
||||
@@ -286,7 +286,6 @@ func (e *ONNXEmbedder) ensureInitialized() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if useCUDA {
|
||||
e.logger.Info("Using CUDA for ONNX inference")
|
||||
} else {
|
||||
|
||||
26
rag/rag.go
26
rag/rag.go
@@ -19,10 +19,7 @@ import (
|
||||
"github.com/neurosnap/sentences/english"
|
||||
)
|
||||
|
||||
const (
|
||||
// batchTimeout is the maximum time allowed for embedding a single batch
|
||||
batchTimeout = 2 * time.Minute
|
||||
)
|
||||
const ()
|
||||
|
||||
var (
|
||||
// Status messages for TUI integration
|
||||
@@ -102,10 +99,6 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
||||
return rag, nil
|
||||
}
|
||||
|
||||
// wordCounter returns the number of whitespace-separated words in sentence.
//
// Words are delimited by any run of Unicode whitespace (spaces, tabs,
// newlines), so repeated separators never produce phantom empty words and an
// empty or whitespace-only sentence counts as zero words.
func wordCounter(sentence string) int {
	// strings.Fields trims and collapses whitespace in one pass; splitting on
	// a single " " would report 1 for "" and over-count on double spaces.
	return len(strings.Fields(sentence))
}
|
||||
|
||||
func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
|
||||
if len(sentences) == 0 {
|
||||
return nil
|
||||
@@ -181,7 +174,6 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
fileText, err := ExtractText(fpath)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -190,7 +182,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
|
||||
|
||||
// Send initial status (non-blocking with retry)
|
||||
r.sendStatusNonBlocking(LoadedFileRAGStatus)
|
||||
|
||||
tokenizer, err := english.NewSentenceTokenizer(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -210,7 +201,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
|
||||
if len(paragraphs) == 0 {
|
||||
return errors.New("no valid paragraphs found in file")
|
||||
}
|
||||
|
||||
totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize
|
||||
r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize)
|
||||
|
||||
@@ -223,7 +213,7 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
|
||||
concurrency = 1
|
||||
}
|
||||
// If using ONNX embedder, limit concurrency to 1 due to mutex serialization
|
||||
isONNX := false
|
||||
var isONNX bool
|
||||
if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX {
|
||||
concurrency = 1
|
||||
}
|
||||
@@ -258,7 +248,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
|
||||
// Ensure task channel is closed when this goroutine exits
|
||||
defer close(taskCh)
|
||||
r.logger.Debug("task distributor started", "total_batches", totalBatches)
|
||||
|
||||
for i := 0; i < totalBatches; i++ {
|
||||
start := i * r.cfg.RAGBatchSize
|
||||
end := start + r.cfg.RAGBatchSize
|
||||
@@ -304,7 +293,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
|
||||
resultsBuffer := make(map[int]batchResult)
|
||||
filename := path.Base(fpath)
|
||||
batchesProcessed := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -382,7 +370,6 @@ func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Debug("finished writing vectors", "batches", batchesProcessed)
|
||||
r.resetIdleTimer()
|
||||
r.sendStatusNonBlocking(FinishedRAGStatus)
|
||||
@@ -406,7 +393,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for task := range taskCh {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -432,7 +418,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b
|
||||
r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex)
|
||||
continue
|
||||
}
|
||||
|
||||
// Embed with retry for API embedder
|
||||
embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3)
|
||||
if err != nil {
|
||||
@@ -444,7 +429,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Send result with context awareness
|
||||
select {
|
||||
case resultCh <- batchResult{
|
||||
@@ -465,7 +449,6 @@ func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan b
|
||||
// embedWithRetry attempts embedding with exponential backoff for API embedder
|
||||
func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Exponential backoff
|
||||
@@ -473,13 +456,11 @@ func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetrie
|
||||
if backoff > 10*time.Second {
|
||||
backoff = 10 * time.Second
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries)
|
||||
}
|
||||
|
||||
@@ -499,7 +480,6 @@ func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetrie
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
@@ -509,7 +489,6 @@ func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filen
|
||||
// Empty batch, skip
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check context before starting
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -534,7 +513,6 @@ func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filen
|
||||
r.sendStatusNonBlocking(ErrRAGStatus)
|
||||
return fmt.Errorf("failed to write vectors batch: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -64,7 +64,6 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
|
||||
return err
|
||||
}
|
||||
embeddingSize := len(row.Embeddings)
|
||||
|
||||
// Start transaction
|
||||
tx, err := vs.sqlxDB.Beginx()
|
||||
if err != nil {
|
||||
@@ -72,7 +71,7 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -86,14 +85,12 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
|
||||
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert into FTS table
|
||||
ftsQuery := `INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES (?, ?, ?, ?)`
|
||||
if _, err := tx.Exec(ftsQuery, row.Slug, row.RawText, row.FileName, embeddingSize); err != nil {
|
||||
vs.logger.Error("failed to write to FTS table", "error", err, "slug", row.Slug)
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
vs.logger.Error("failed to commit transaction", "error", err)
|
||||
@@ -133,7 +130,6 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx, err := vs.sqlxDB.Beginx()
|
||||
if err != nil {
|
||||
@@ -141,7 +137,7 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error {
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -161,7 +157,6 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error {
|
||||
vs.logger.Error("failed to write vectors batch", "error", err, "batch_size", len(rows))
|
||||
return err
|
||||
}
|
||||
|
||||
// Build batch insert for FTS table
|
||||
ftsPlaceholders := make([]string, 0, len(rows))
|
||||
ftsArgs := make([]any, 0, len(rows)*4)
|
||||
@@ -170,15 +165,12 @@ func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error {
|
||||
ftsPlaceholders = append(ftsPlaceholders, "(?, ?, ?, ?)")
|
||||
ftsArgs = append(ftsArgs, row.Slug, row.RawText, row.FileName, embeddingSize)
|
||||
}
|
||||
ftsQuery := fmt.Sprintf(
|
||||
"INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES %s",
|
||||
strings.Join(ftsPlaceholders, ", "),
|
||||
)
|
||||
ftsQuery := "INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES " +
|
||||
strings.Join(ftsPlaceholders, ", ")
|
||||
if _, err := tx.Exec(ftsQuery, ftsArgs...); err != nil {
|
||||
vs.logger.Error("failed to write FTS batch", "error", err, "batch_size", len(rows))
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
vs.logger.Error("failed to commit transaction", "error", err)
|
||||
@@ -218,14 +210,12 @@ func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.Vec
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
|
||||
rows, err := vs.sqlxDB.Query(querySQL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type SearchResult struct {
|
||||
vector models.VectorRow
|
||||
distance float32
|
||||
@@ -241,7 +231,6 @@ func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.Vec
|
||||
vs.logger.Error("failed to scan row", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
storedEmbeddings := DeserializeVector(embeddingsBlob)
|
||||
similarity := cosineSimilarity(query, storedEmbeddings)
|
||||
distance := 1 - similarity
|
||||
@@ -264,7 +253,6 @@ func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.Vec
|
||||
topResults = topResults[:limit]
|
||||
}
|
||||
}
|
||||
|
||||
results := make([]models.VectorRow, 0, len(topResults))
|
||||
for _, result := range topResults {
|
||||
result.vector.Distance = result.distance
|
||||
|
||||
Reference in New Issue
Block a user