Enha (rag): query each doc
This commit is contained in:
38
rag/rag.go
38
rag/rag.go
@@ -286,11 +286,14 @@ func (r *RAG) RefineQuery(query string) string {
|
|||||||
return original
|
return original
|
||||||
}
|
}
|
||||||
query = strings.ToLower(query)
|
query = strings.ToLower(query)
|
||||||
|
words := strings.Fields(query)
|
||||||
|
if len(words) >= 3 {
|
||||||
for _, stopWord := range stopWords {
|
for _, stopWord := range stopWords {
|
||||||
wordPattern := `\b` + stopWord + `\b`
|
wordPattern := `\b` + stopWord + `\b`
|
||||||
re := regexp.MustCompile(wordPattern)
|
re := regexp.MustCompile(wordPattern)
|
||||||
query = re.ReplaceAllString(query, "")
|
query = re.ReplaceAllString(query, "")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
query = strings.TrimSpace(query)
|
query = strings.TrimSpace(query)
|
||||||
if len(query) < 5 {
|
if len(query) < 5 {
|
||||||
return original
|
return original
|
||||||
@@ -340,6 +343,36 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
|
|||||||
if len(parts) == 0 {
|
if len(parts) == 0 {
|
||||||
return variations
|
return variations
|
||||||
}
|
}
|
||||||
|
// Get loaded filenames to filter out filename terms
|
||||||
|
filenames, err := r.storage.ListFiles()
|
||||||
|
if err == nil && len(filenames) > 0 {
|
||||||
|
// Convert to lowercase for case-insensitive matching
|
||||||
|
lowerFilenames := make([]string, len(filenames))
|
||||||
|
for i, f := range filenames {
|
||||||
|
lowerFilenames[i] = strings.ToLower(f)
|
||||||
|
}
|
||||||
|
filteredParts := make([]string, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
partLower := strings.ToLower(part)
|
||||||
|
skip := false
|
||||||
|
for _, fn := range lowerFilenames {
|
||||||
|
if strings.Contains(fn, partLower) || strings.Contains(partLower, fn) {
|
||||||
|
skip = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !skip {
|
||||||
|
filteredParts = append(filteredParts, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If filteredParts not empty and different from original, add filtered query
|
||||||
|
if len(filteredParts) > 0 && len(filteredParts) != len(parts) {
|
||||||
|
filteredQuery := strings.Join(filteredParts, " ")
|
||||||
|
if len(filteredQuery) >= 5 {
|
||||||
|
variations = append(variations, filteredQuery)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
trimmed := strings.Join(parts[:len(parts)-1], " ")
|
trimmed := strings.Join(parts[:len(parts)-1], " ")
|
||||||
if len(trimmed) >= 5 {
|
if len(trimmed) >= 5 {
|
||||||
@@ -403,9 +436,14 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
|||||||
})
|
})
|
||||||
unique := make([]models.VectorRow, 0)
|
unique := make([]models.VectorRow, 0)
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
|
fileCounts := make(map[string]int)
|
||||||
for i := range scored {
|
for i := range scored {
|
||||||
if !seen[scored[i].row.Slug] {
|
if !seen[scored[i].row.Slug] {
|
||||||
|
if fileCounts[scored[i].row.FileName] >= 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
seen[scored[i].row.Slug] = true
|
seen[scored[i].row.Slug] = true
|
||||||
|
fileCounts[scored[i].row.FileName]++
|
||||||
unique = append(unique, scored[i].row)
|
unique = append(unique, scored[i].row)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package rag
|
package rag
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gf-lt/models"
|
"gf-lt/models"
|
||||||
@@ -221,11 +222,41 @@ func (vs *VectorStorage) SearchKeyword(query string, limit int) ([]models.Vector
|
|||||||
WHERE fts_embeddings MATCH ?
|
WHERE fts_embeddings MATCH ?
|
||||||
ORDER BY score
|
ORDER BY score
|
||||||
LIMIT ?`
|
LIMIT ?`
|
||||||
|
|
||||||
|
// Try original query first
|
||||||
rows, err := vs.sqlxDB.Query(ftsQuery, query, limit)
|
rows, err := vs.sqlxDB.Query(ftsQuery, query, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("FTS search failed: %w", err)
|
return nil, fmt.Errorf("FTS search failed: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
results, err := vs.scanRows(rows)
|
||||||
|
rows.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no results and query contains multiple terms, try OR fallback
|
||||||
|
if len(results) == 0 && strings.Contains(query, " ") && !strings.Contains(strings.ToUpper(query), " OR ") {
|
||||||
|
// Build OR query: term1 OR term2 OR term3
|
||||||
|
terms := strings.Fields(query)
|
||||||
|
if len(terms) > 1 {
|
||||||
|
orQuery := strings.Join(terms, " OR ")
|
||||||
|
rows, err := vs.sqlxDB.Query(ftsQuery, orQuery, limit)
|
||||||
|
if err != nil {
|
||||||
|
// Return original empty results rather than error
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
orResults, err := vs.scanRows(rows)
|
||||||
|
rows.Close()
|
||||||
|
if err == nil {
|
||||||
|
results = orResults
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanRows converts SQL rows to VectorRow slice
|
||||||
|
func (vs *VectorStorage) scanRows(rows *sql.Rows) ([]models.VectorRow, error) {
|
||||||
var results []models.VectorRow
|
var results []models.VectorRow
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var slug, rawText, fileName string
|
var slug, rawText, fileName string
|
||||||
|
|||||||
Reference in New Issue
Block a user