Enhancement (rag): query each document

This commit is contained in:
Grail Finder
2026-03-06 13:17:49 +03:00
parent f9866bcf5a
commit 62ec55505c
2 changed files with 74 additions and 5 deletions

View File

@@ -286,10 +286,13 @@ func (r *RAG) RefineQuery(query string) string {
return original return original
} }
query = strings.ToLower(query) query = strings.ToLower(query)
for _, stopWord := range stopWords { words := strings.Fields(query)
wordPattern := `\b` + stopWord + `\b` if len(words) >= 3 {
re := regexp.MustCompile(wordPattern) for _, stopWord := range stopWords {
query = re.ReplaceAllString(query, "") wordPattern := `\b` + stopWord + `\b`
re := regexp.MustCompile(wordPattern)
query = re.ReplaceAllString(query, "")
}
} }
query = strings.TrimSpace(query) query = strings.TrimSpace(query)
if len(query) < 5 { if len(query) < 5 {
@@ -340,6 +343,36 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
if len(parts) == 0 { if len(parts) == 0 {
return variations return variations
} }
// Get loaded filenames to filter out filename terms
filenames, err := r.storage.ListFiles()
if err == nil && len(filenames) > 0 {
// Convert to lowercase for case-insensitive matching
lowerFilenames := make([]string, len(filenames))
for i, f := range filenames {
lowerFilenames[i] = strings.ToLower(f)
}
filteredParts := make([]string, 0, len(parts))
for _, part := range parts {
partLower := strings.ToLower(part)
skip := false
for _, fn := range lowerFilenames {
if strings.Contains(fn, partLower) || strings.Contains(partLower, fn) {
skip = true
break
}
}
if !skip {
filteredParts = append(filteredParts, part)
}
}
// If filteredParts not empty and different from original, add filtered query
if len(filteredParts) > 0 && len(filteredParts) != len(parts) {
filteredQuery := strings.Join(filteredParts, " ")
if len(filteredQuery) >= 5 {
variations = append(variations, filteredQuery)
}
}
}
if len(parts) >= 2 { if len(parts) >= 2 {
trimmed := strings.Join(parts[:len(parts)-1], " ") trimmed := strings.Join(parts[:len(parts)-1], " ")
if len(trimmed) >= 5 { if len(trimmed) >= 5 {
@@ -403,9 +436,14 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
}) })
unique := make([]models.VectorRow, 0) unique := make([]models.VectorRow, 0)
seen := make(map[string]bool) seen := make(map[string]bool)
fileCounts := make(map[string]int)
for i := range scored { for i := range scored {
if !seen[scored[i].row.Slug] { if !seen[scored[i].row.Slug] {
if fileCounts[scored[i].row.FileName] >= 2 {
continue
}
seen[scored[i].row.Slug] = true seen[scored[i].row.Slug] = true
fileCounts[scored[i].row.FileName]++
unique = append(unique, scored[i].row) unique = append(unique, scored[i].row)
} }
} }

View File

@@ -1,6 +1,7 @@
package rag package rag
import ( import (
"database/sql"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"gf-lt/models" "gf-lt/models"
@@ -221,11 +222,41 @@ func (vs *VectorStorage) SearchKeyword(query string, limit int) ([]models.Vector
WHERE fts_embeddings MATCH ? WHERE fts_embeddings MATCH ?
ORDER BY score ORDER BY score
LIMIT ?` LIMIT ?`
// Try original query first
rows, err := vs.sqlxDB.Query(ftsQuery, query, limit) rows, err := vs.sqlxDB.Query(ftsQuery, query, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("FTS search failed: %w", err) return nil, fmt.Errorf("FTS search failed: %w", err)
} }
defer rows.Close() results, err := vs.scanRows(rows)
rows.Close()
if err != nil {
return nil, err
}
// If no results and query contains multiple terms, try OR fallback
if len(results) == 0 && strings.Contains(query, " ") && !strings.Contains(strings.ToUpper(query), " OR ") {
// Build OR query: term1 OR term2 OR term3
terms := strings.Fields(query)
if len(terms) > 1 {
orQuery := strings.Join(terms, " OR ")
rows, err := vs.sqlxDB.Query(ftsQuery, orQuery, limit)
if err != nil {
// Return original empty results rather than error
return results, nil
}
orResults, err := vs.scanRows(rows)
rows.Close()
if err == nil {
results = orResults
}
}
}
return results, nil
}
// scanRows converts SQL rows to VectorRow slice
func (vs *VectorStorage) scanRows(rows *sql.Rows) ([]models.VectorRow, error) {
var results []models.VectorRow var results []models.VectorRow
for rows.Next() { for rows.Next() {
var slug, rawText, fileName string var slug, rawText, fileName string