diff --git a/config.example.toml b/config.example.toml index 15c8ca6..1698189 100644 --- a/config.example.toml +++ b/config.example.toml @@ -29,6 +29,7 @@ AutoCleanToolCallsFromCtx = false # rag settings RAGBatchSize = 1 RAGWordLimit = 80 +RAGOverlapWords = 16 RAGDir = "ragimport" # extra tts TTS_ENABLED = false diff --git a/config/config.go b/config/config.go index 29d5744..fab3237 100644 --- a/config/config.go +++ b/config/config.go @@ -40,9 +40,10 @@ type Config struct { EmbedTokenizerPath string `toml:"EmbedTokenizerPath"` EmbedDims int `toml:"EmbedDims"` // rag settings - RAGDir string `toml:"RAGDir"` - RAGBatchSize int `toml:"RAGBatchSize"` - RAGWordLimit uint32 `toml:"RAGWordLimit"` + RAGDir string `toml:"RAGDir"` + RAGBatchSize int `toml:"RAGBatchSize"` + RAGWordLimit uint32 `toml:"RAGWordLimit"` + RAGOverlapWords uint32 `toml:"RAGOverlapWords"` // deepseek DeepSeekChatAPI string `toml:"DeepSeekChatAPI"` DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"` diff --git a/rag/rag.go b/rag/rag.go index d64a3e1..4e11a0d 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -73,6 +73,74 @@ func wordCounter(sentence string) int { return len(strings.Split(strings.TrimSpace(sentence), " ")) } +func createChunks(sentences []string, wordLimit, overlapWords uint32) []string { + if len(sentences) == 0 { + return nil + } + if overlapWords >= wordLimit { + overlapWords = wordLimit / 2 + } + var chunks []string + i := 0 + for i < len(sentences) { + var chunkWords []string + wordCount := 0 + j := i + for j < len(sentences) && wordCount <= int(wordLimit) { + sentence := sentences[j] + words := strings.Fields(sentence) + chunkWords = append(chunkWords, sentence) + wordCount += len(words) + j++ + // If this sentence alone exceeds limit, still include it and stop + if wordCount > int(wordLimit) { + break + } + } + if len(chunkWords) == 0 { + break + } + chunk := strings.Join(chunkWords, " ") + chunks = append(chunks, chunk) + if j >= len(sentences) { + break + } + // Move i forward by skipping overlap + if overlapWords == 0 { + i = j + continue + } + // Calculate how many sentences to skip to achieve overlapWords + overlapRemaining := int(overlapWords) + newI := i + for newI < j && overlapRemaining > 0 { + words := len(strings.Fields(sentences[newI])) + overlapRemaining -= words + if overlapRemaining >= 0 { + newI++ + } + } + if newI == i { + newI = j + } + i = newI + } + return chunks +} + +func sanitizeFTSQuery(query string) string { + // Remove double quotes and other problematic characters for FTS5 + query = strings.ReplaceAll(query, "\"", " ") + query = strings.ReplaceAll(query, "'", " ") + query = strings.ReplaceAll(query, ";", " ") + query = strings.ReplaceAll(query, "\\", " ") + query = strings.TrimSpace(query) + if query == "" { + return "*" // match all + } + return query +} + func (r *RAG) LoadRAG(fpath string) error { r.mu.Lock() defer r.mu.Unlock() @@ -95,31 +163,8 @@ func (r *RAG) LoadRAG(fpath string) error { for i, s := range sentences { sents[i] = s.Text } - // Group sentences into paragraphs based on word limit - paragraphs := []string{} - par := strings.Builder{} - for i := 0; i < len(sents); i++ { - if strings.TrimSpace(sents[i]) != "" { - if par.Len() > 0 { - par.WriteString(" ") - } - par.WriteString(sents[i]) - } - if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) { - paragraph := strings.TrimSpace(par.String()) - if paragraph != "" { - paragraphs = append(paragraphs, paragraph) - } - par.Reset() - } - } - // Handle any remaining content in the paragraph buffer - if par.Len() > 0 { - paragraph := strings.TrimSpace(par.String()) - if paragraph != "" { - paragraphs = append(paragraphs, paragraph) - } - } + // Create chunks with overlap + paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords) // Adjust batch size if needed if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 { r.cfg.RAGBatchSize = len(paragraphs) @@ -205,9 +250,15 @@ func (r *RAG) LineToVector(line string) ([]float32, error) { return r.embedder.Embed(line) } -func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { +func (r *RAG) SearchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) { r.resetIdleTimer() - return r.storage.SearchClosest(emb.Embedding) + return r.storage.SearchClosest(emb.Embedding, limit) +} + +func (r *RAG) SearchKeyword(query string, limit int) ([]models.VectorRow, error) { + r.resetIdleTimer() + sanitized := sanitizeFTSQuery(query) + return r.storage.SearchKeyword(sanitized, limit) } func (r *RAG) ListLoaded() ([]string, error) { @@ -393,7 +444,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string Embedding: emb, Index: 0, } - topResults, err := r.SearchEmb(embResp) + topResults, err := r.SearchEmb(embResp, 1) if err != nil { r.logger.Error("failed to search for synthesis context", "error", err) return "", err @@ -422,7 +473,9 @@ func truncateString(s string, maxLen int) string { func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { refined := r.RefineQuery(query) variations := r.GenerateQueryVariations(refined) - allResults := make([]models.VectorRow, 0) + + // Collect embedding search results from all variations + var embResults []models.VectorRow seen := make(map[string]bool) for _, q := range variations { emb, err := r.LineToVector(q) @@ -430,29 +483,78 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { r.logger.Error("failed to embed query variation", "error", err, "query", q) continue } - embResp := &models.EmbeddingResp{ Embedding: emb, Index: 0, } - - results, err := r.SearchEmb(embResp) + results, err := r.SearchEmb(embResp, limit*2) // Get more candidates if err != nil { r.logger.Error("failed to search embeddings", "error", err, "query", q) continue } - for _, row := range results { if !seen[row.Slug] { seen[row.Slug] = true - allResults = append(allResults, row) + embResults = append(embResults, row) } } } - reranked := r.RerankResults(allResults, query) - if len(reranked) > limit { - reranked = reranked[:limit] + // Sort embedding results by distance (lower is better) + sort.Slice(embResults, func(i, j int) bool { + return embResults[i].Distance < embResults[j].Distance + }) + + // Perform keyword search + kwResults, err := r.SearchKeyword(refined, limit*2) + if err != nil { + r.logger.Warn("keyword search failed, using only embeddings", "error", err) + kwResults = nil } + // Sort keyword results by distance (already sorted by BM25 score) + // kwResults already sorted by distance (lower is better) + + // Combine using Reciprocal Rank Fusion (RRF) + const rrfK = 60 + type scoredRow struct { + row models.VectorRow + score float64 + } + scoreMap := make(map[string]float64) + // Add embedding results + for rank, row := range embResults { + score := 1.0 / (float64(rank) + rrfK) + scoreMap[row.Slug] += score + } + // Add keyword results + for rank, row := range kwResults { + score := 1.0 / (float64(rank) + rrfK) + scoreMap[row.Slug] += score + // Ensure row exists in combined results + if _, exists := seen[row.Slug]; !exists { + embResults = append(embResults, row) + } + } + // Create slice of scored rows + scoredRows := make([]scoredRow, 0, len(embResults)) + for _, row := range embResults { + score := scoreMap[row.Slug] + scoredRows = append(scoredRows, scoredRow{row: row, score: score}) + } + // Sort by descending RRF score + sort.Slice(scoredRows, func(i, j int) bool { + return scoredRows[i].score > scoredRows[j].score + }) + // Take top limit + if len(scoredRows) > limit { + scoredRows = scoredRows[:limit] + } + // Convert back to VectorRow + finalResults := make([]models.VectorRow, len(scoredRows)) + for i, sr := range scoredRows { + finalResults[i] = sr.row + } + // Apply reranking heuristics + reranked := r.RerankResults(finalResults, query) return reranked, nil } diff --git a/rag/storage.go b/rag/storage.go index 52f6859..08e9d2a 100644 --- a/rag/storage.go +++ b/rag/storage.go @@ -62,6 +62,18 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { if err != nil { return err } + embeddingSize := len(row.Embeddings) + + // Start transaction + tx, err := vs.sqlxDB.Beginx() + if err != nil { + return err + } + defer func() { + if err != nil { + tx.Rollback() + } + }() // Serialize the embeddings to binary serializedEmbeddings := SerializeVector(row.Embeddings) @@ -69,10 +81,23 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { "INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName, ) - if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil { + if _, err := tx.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil { vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug) return err } + + // Insert into FTS table + ftsQuery := `INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES (?, ?, ?, ?)` + if _, err := tx.Exec(ftsQuery, row.Slug, row.RawText, row.FileName, embeddingSize); err != nil { + vs.logger.Error("failed to write to FTS table", "error", err, "slug", row.Slug) + return err + } + + err = tx.Commit() + if err != nil { + vs.logger.Error("failed to commit transaction", "error", err) + return err + } return nil } @@ -98,16 +123,15 @@ func (vs *VectorStorage) getTableName(emb []float32) (string, error) { } // SearchClosest finds vectors closest to the query vector using efficient cosine similarity calculation -func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, error) { +func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.VectorRow, error) { + if limit <= 0 { + limit = 10 + } tableName, err := vs.getTableName(query) if err != nil { return nil, err } - // For better performance, instead of loading all vectors at once, - // we'll implement batching and potentially add L2 distance-based pre-filtering - // since cosine similarity is related to L2 distance for normalized vectors - querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName rows, err := vs.sqlxDB.Query(querySQL) if err != nil { @@ -115,13 +139,11 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err } defer rows.Close() - // Use a min-heap or simple slice to keep track of top 3 closest vectors type SearchResult struct { vector models.VectorRow distance float32 } var topResults []SearchResult - // Process vectors one by one to avoid loading everything into memory for rows.Next() { var ( embeddingsBlob []byte @@ -134,10 +156,8 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err } storedEmbeddings := DeserializeVector(embeddingsBlob) - - // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) similarity := cosineSimilarity(query, storedEmbeddings) - distance := 1 - similarity // Convert to distance where 0 is most similar + distance := 1 - similarity result := SearchResult{ vector: models.VectorRow{ @@ -149,20 +169,15 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err distance: distance, } - // Add to top results and maintain only top 3 topResults = append(topResults, result) - - // Sort and keep only top 3 sort.Slice(topResults, func(i, j int) bool { return topResults[i].distance < topResults[j].distance }) - - if len(topResults) > 3 { - topResults = topResults[:3] // Keep only closest 3 + if len(topResults) > limit { + topResults = topResults[:limit] } } - // Convert back to VectorRow slice results := make([]models.VectorRow, 0, len(topResults)) for _, result := range topResults { result.vector.Distance = result.distance @@ -171,6 +186,70 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err return results, nil } +// GetVectorBySlug retrieves a vector row by its slug +func (vs *VectorStorage) GetVectorBySlug(slug string) (*models.VectorRow, error) { + embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} + for _, size := range embeddingSizes { + table := fmt.Sprintf("embeddings_%d", size) + query := fmt.Sprintf("SELECT embeddings, slug, raw_text, filename FROM %s WHERE slug = ?", table) + row := vs.sqlxDB.QueryRow(query, slug) + var ( + embeddingsBlob []byte + retrievedSlug, rawText, fileName string + ) + if err := row.Scan(&embeddingsBlob, &retrievedSlug, &rawText, &fileName); err != nil { + // No row in this table, continue to next size + continue + } + storedEmbeddings := DeserializeVector(embeddingsBlob) + return &models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: retrievedSlug, + RawText: rawText, + FileName: fileName, + }, nil + } + return nil, fmt.Errorf("vector with slug %s not found", slug) +} + +// SearchKeyword performs full-text search using FTS5 +func (vs *VectorStorage) SearchKeyword(query string, limit int) ([]models.VectorRow, error) { + // Use FTS5 bm25 ranking. bm25 returns negative values where more negative is better. + // We'll order by bm25 (ascending) and limit. + ftsQuery := `SELECT slug, raw_text, filename, bm25(fts_embeddings) as score + FROM fts_embeddings + WHERE fts_embeddings MATCH ? + ORDER BY score + LIMIT ?` + rows, err := vs.sqlxDB.Query(ftsQuery, query, limit) + if err != nil { + return nil, fmt.Errorf("FTS search failed: %w", err) + } + defer rows.Close() + var results []models.VectorRow + for rows.Next() { + var slug, rawText, fileName string + var score float64 + if err := rows.Scan(&slug, &rawText, &fileName, &score); err != nil { + vs.logger.Error("failed to scan FTS row", "error", err) + continue + } + // Convert BM25 score to distance-like metric (lower is better) + // BM25 is negative, more negative is better. We'll normalize to positive distance. + distance := float32(-score) // Make positive (since score is negative) + if distance < 0 { + distance = 0 + } + results = append(results, models.VectorRow{ + Slug: slug, + RawText: rawText, + FileName: fileName, + Distance: distance, + }) + } + return results, nil +} + // ListFiles returns a list of all loaded files func (vs *VectorStorage) ListFiles() ([]string, error) { fileLists := make([][]string, 0) @@ -215,6 +294,10 @@ func (vs *VectorStorage) ListFiles() ([]string, error) { // RemoveEmbByFileName removes all embeddings associated with a specific filename func (vs *VectorStorage) RemoveEmbByFileName(filename string) error { var errors []string + // Delete from FTS table first + if _, err := vs.sqlxDB.Exec("DELETE FROM fts_embeddings WHERE filename = ?", filename); err != nil { + errors = append(errors, err.Error()) + } embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} for _, size := range embeddingSizes { table := fmt.Sprintf("embeddings_%d", size) diff --git a/storage/migrations/003_add_fts.down.sql b/storage/migrations/003_add_fts.down.sql new file mode 100644 index 0000000..e565fd5 --- /dev/null +++ b/storage/migrations/003_add_fts.down.sql @@ -0,0 +1,2 @@ +-- Drop FTS5 virtual table +DROP TABLE IF EXISTS fts_embeddings; \ No newline at end of file diff --git a/storage/migrations/003_add_fts.up.sql b/storage/migrations/003_add_fts.up.sql new file mode 100644 index 0000000..114586a --- /dev/null +++ b/storage/migrations/003_add_fts.up.sql @@ -0,0 +1,15 @@ +-- Create FTS5 virtual table for full-text search +CREATE VIRTUAL TABLE IF NOT EXISTS fts_embeddings USING fts5( + slug UNINDEXED, + raw_text, + filename UNINDEXED, + embedding_size UNINDEXED, + tokenize='porter unicode61' -- Use porter stemmer and unicode61 tokenizer +); + +-- Create triggers to maintain FTS table when embeddings are inserted/deleted +-- Note: We'll handle inserts/deletes programmatically for simplicity +-- but triggers could be added here if needed. + +-- Indexes for performance (FTS5 manages its own indexes) +-- No additional indexes needed for FTS5 virtual table. \ No newline at end of file diff --git a/storage/migrations/004_populate_fts.down.sql b/storage/migrations/004_populate_fts.down.sql new file mode 100644 index 0000000..2b5c756 --- /dev/null +++ b/storage/migrations/004_populate_fts.down.sql @@ -0,0 +1,2 @@ +-- Clear FTS table (optional) +DELETE FROM fts_embeddings; \ No newline at end of file diff --git a/storage/migrations/004_populate_fts.up.sql b/storage/migrations/004_populate_fts.up.sql new file mode 100644 index 0000000..1d1b16a --- /dev/null +++ b/storage/migrations/004_populate_fts.up.sql @@ -0,0 +1,26 @@ +-- Populate FTS table with existing embeddings +DELETE FROM fts_embeddings; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 384 FROM embeddings_384; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 768 FROM embeddings_768; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 1024 FROM embeddings_1024; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 1536 FROM embeddings_1536; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 2048 FROM embeddings_2048; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 3072 FROM embeddings_3072; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 4096 FROM embeddings_4096; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 5120 FROM embeddings_5120; \ No newline at end of file diff --git a/storage/vector.go b/storage/vector.go index 75f5c9a..e3bbb89 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "gf-lt/models" + "sort" "unsafe" "github.com/jmoiron/sqlx" @@ -11,7 +12,7 @@ import ( type VectorRepo interface { WriteVector(*models.VectorRow) error - SearchClosest(q []float32) ([]models.VectorRow, error) + SearchClosest(q []float32, limit int) ([]models.VectorRow, error) ListFiles() ([]string, error) RemoveEmbByFileName(filename string) error DB() *sqlx.DB @@ -79,7 +80,7 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error { return err } -func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { +func (p ProviderSQL) SearchClosest(q []float32, limit int) ([]models.VectorRow, error) { tableName, err := fetchTableName(q) if err != nil { return nil, err @@ -94,7 +95,7 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { vector models.VectorRow distance float32 } - var topResults []SearchResult + var allResults []SearchResult for rows.Next() { var ( embeddingsBlob []byte @@ -119,28 +120,19 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { }, distance: distance, } - - // Add to top results and maintain only top results - topResults = append(topResults, result) - - // Sort and keep only top results - // We'll keep the top 3 closest vectors - if len(topResults) > 3 { - // Simple sort and truncate to maintain only 3 best matches - for i := 0; i < len(topResults); i++ { - for j := i + 1; j < len(topResults); j++ { - if topResults[i].distance > topResults[j].distance { - topResults[i], topResults[j] = topResults[j], topResults[i] - } - } - } - topResults = topResults[:3] - } + allResults = append(allResults, result) + } + // Sort by distance + sort.Slice(allResults, func(i, j int) bool { + return allResults[i].distance < allResults[j].distance + }) + // Truncate to limit + if len(allResults) > limit { + allResults = allResults[:limit] } - // Convert back to VectorRow slice - results := make([]models.VectorRow, len(topResults)) - for i, result := range topResults { + results := make([]models.VectorRow, len(allResults)) + for i, result := range allResults { result.vector.Distance = result.distance results[i] = result.vector }