From 62ec55505ca07701ee6a976895d910b051e725b9 Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Fri, 6 Mar 2026 13:17:49 +0300 Subject: [PATCH] Enha (rag): query each doc --- rag/rag.go | 46 ++++++++++++++++++++++++++++++++++++++++++---- rag/storage.go | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 5 deletions(-) diff --git a/rag/rag.go b/rag/rag.go index 4e11a0d..9271b60 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -286,10 +286,13 @@ func (r *RAG) RefineQuery(query string) string { return original } query = strings.ToLower(query) - for _, stopWord := range stopWords { - wordPattern := `\b` + stopWord + `\b` - re := regexp.MustCompile(wordPattern) - query = re.ReplaceAllString(query, "") + words := strings.Fields(query) + if len(words) >= 3 { + for _, stopWord := range stopWords { + wordPattern := `\b` + stopWord + `\b` + re := regexp.MustCompile(wordPattern) + query = re.ReplaceAllString(query, "") + } } query = strings.TrimSpace(query) if len(query) < 5 { @@ -340,6 +343,36 @@ func (r *RAG) GenerateQueryVariations(query string) []string { if len(parts) == 0 { return variations } + // Get loaded filenames to filter out filename terms + filenames, err := r.storage.ListFiles() + if err == nil && len(filenames) > 0 { + // Convert to lowercase for case-insensitive matching + lowerFilenames := make([]string, len(filenames)) + for i, f := range filenames { + lowerFilenames[i] = strings.ToLower(f) + } + filteredParts := make([]string, 0, len(parts)) + for _, part := range parts { + partLower := strings.ToLower(part) + skip := false + for _, fn := range lowerFilenames { + if strings.Contains(fn, partLower) || strings.Contains(partLower, fn) { + skip = true + break + } + } + if !skip { + filteredParts = append(filteredParts, part) + } + } + // If filteredParts not empty and different from original, add filtered query + if len(filteredParts) > 0 && len(filteredParts) != len(parts) { + filteredQuery := strings.Join(filteredParts, " ") + if len(filteredQuery) >= 5 { + variations = append(variations, filteredQuery) + } + } + } if len(parts) >= 2 { trimmed := strings.Join(parts[:len(parts)-1], " ") if len(trimmed) >= 5 { @@ -403,9 +436,14 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V }) unique := make([]models.VectorRow, 0) seen := make(map[string]bool) + fileCounts := make(map[string]int) for i := range scored { if !seen[scored[i].row.Slug] { + if fileCounts[scored[i].row.FileName] >= 2 { + continue + } seen[scored[i].row.Slug] = true + fileCounts[scored[i].row.FileName]++ unique = append(unique, scored[i].row) } } diff --git a/rag/storage.go b/rag/storage.go index 08e9d2a..110cea2 100644 --- a/rag/storage.go +++ b/rag/storage.go @@ -1,6 +1,7 @@ package rag import ( + "database/sql" "encoding/binary" "fmt" "gf-lt/models" @@ -221,11 +222,41 @@ func (vs *VectorStorage) SearchKeyword(query string, limit int) ([]models.Vector WHERE fts_embeddings MATCH ? ORDER BY score LIMIT ?` + + // Try original query first rows, err := vs.sqlxDB.Query(ftsQuery, query, limit) if err != nil { return nil, fmt.Errorf("FTS search failed: %w", err) } - defer rows.Close() + results, err := vs.scanRows(rows) + rows.Close() + if err != nil { + return nil, err + } + + // If no results and query contains multiple terms, try OR fallback + if len(results) == 0 && strings.Contains(query, " ") && !strings.Contains(strings.ToUpper(query), " OR ") { + // Build OR query: term1 OR term2 OR term3 + terms := strings.Fields(query) + if len(terms) > 1 { + orQuery := strings.Join(terms, " OR ") + rows, err := vs.sqlxDB.Query(ftsQuery, orQuery, limit) + if err != nil { + // Return original empty results rather than error + return results, nil + } + orResults, err := vs.scanRows(rows) + rows.Close() + if err == nil { + results = orResults + } + } + } + return results, nil +} + +// scanRows converts SQL rows to VectorRow slice +func (vs *VectorStorage) scanRows(rows *sql.Rows) ([]models.VectorRow, error) { var results []models.VectorRow for rows.Next() { var slug, rawText, fileName string