Enha (rag): semantic hybrid search
This commit is contained in:
167
rag/rag.go
167
rag/rag.go
@@ -12,6 +12,7 @@ import (
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -27,8 +28,101 @@ var (
|
||||
FinishedRAGStatus = "finished loading RAG file; press x to exit"
|
||||
LoadedFileRAGStatus = "loaded file"
|
||||
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
||||
|
||||
// stopWords are common words that can be removed from queries when not part of phrases
|
||||
stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right", "about", "like", "such", "than", "then", "also", "too"}
|
||||
)
|
||||
|
||||
// isStopWord checks if a word is in the stop words list
|
||||
func isStopWord(word string) bool {
|
||||
for _, stop := range stopWords {
|
||||
if strings.EqualFold(word, stop) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detectPhrases returns multi-word phrases from a query that should be treated as units
|
||||
func detectPhrases(query string) []string {
|
||||
words := strings.Fields(strings.ToLower(query))
|
||||
var phrases []string
|
||||
|
||||
for i := 0; i < len(words)-1; i++ {
|
||||
word1 := strings.Trim(words[i], ".,!?;:'\"()[]{}")
|
||||
word2 := strings.Trim(words[i+1], ".,!?;:'\"()[]{}")
|
||||
|
||||
// Skip if either word is a stop word or too short
|
||||
if isStopWord(word1) || isStopWord(word2) || len(word1) < 2 || len(word2) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this pair appears to be a meaningful phrase
|
||||
// Simple heuristic: consecutive non-stop words of reasonable length
|
||||
phrase := word1 + " " + word2
|
||||
phrases = append(phrases, phrase)
|
||||
|
||||
// Optionally check for 3-word phrases
|
||||
if i < len(words)-2 {
|
||||
word3 := strings.Trim(words[i+2], ".,!?;:'\"()[]{}")
|
||||
if !isStopWord(word3) && len(word3) >= 2 {
|
||||
phrases = append(phrases, word1+" "+word2+" "+word3)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return phrases
|
||||
}
|
||||
|
||||
// parseSlugIndices extracts batch and chunk indices from a slug
|
||||
// slug format: filename_batch_chunk (e.g., "kjv_bible.epub_1786_0")
|
||||
func parseSlugIndices(slug string) (batch, chunk int, ok bool) {
|
||||
// Find the last two numbers separated by underscores
|
||||
re := regexp.MustCompile(`_(\d+)_(\d+)$`)
|
||||
matches := re.FindStringSubmatch(slug)
|
||||
if matches == nil || len(matches) != 3 {
|
||||
return 0, 0, false
|
||||
}
|
||||
batch, err1 := strconv.Atoi(matches[1])
|
||||
chunk, err2 := strconv.Atoi(matches[2])
|
||||
if err1 != nil || err2 != nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
return batch, chunk, true
|
||||
}
|
||||
|
||||
// areSlugsAdjacent returns true if two slugs are from the same file and have sequential indices
|
||||
func areSlugsAdjacent(slug1, slug2 string) bool {
|
||||
// Extract filename prefix (everything before the last underscore sequence)
|
||||
parts1 := strings.Split(slug1, "_")
|
||||
parts2 := strings.Split(slug2, "_")
|
||||
if len(parts1) < 3 || len(parts2) < 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Compare filename prefixes (all parts except last two)
|
||||
prefix1 := strings.Join(parts1[:len(parts1)-2], "_")
|
||||
prefix2 := strings.Join(parts2[:len(parts2)-2], "_")
|
||||
if prefix1 != prefix2 {
|
||||
return false
|
||||
}
|
||||
|
||||
batch1, chunk1, ok1 := parseSlugIndices(slug1)
|
||||
batch2, chunk2, ok2 := parseSlugIndices(slug2)
|
||||
if !ok1 || !ok2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if they're in same batch and chunks are sequential
|
||||
if batch1 == batch2 && (chunk1 == chunk2+1 || chunk2 == chunk1+1) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if they're in sequential batches and chunk indices suggest continuity
|
||||
// This is heuristic but useful for cross-batch adjacency
|
||||
return false
|
||||
}
|
||||
|
||||
type RAG struct {
|
||||
logger *slog.Logger
|
||||
store storage.FullRepo
|
||||
@@ -155,8 +249,8 @@ func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
|
||||
}
|
||||
|
||||
func sanitizeFTSQuery(query string) string {
|
||||
// Remove double quotes and other problematic characters for FTS5
|
||||
// query = strings.ReplaceAll(query, "\"", " ")
|
||||
// Keep double quotes for FTS5 phrase matching
|
||||
// Remove other problematic characters
|
||||
query = strings.ReplaceAll(query, "'", " ")
|
||||
query = strings.ReplaceAll(query, ";", " ")
|
||||
query = strings.ReplaceAll(query, "\\", " ")
|
||||
@@ -549,7 +643,6 @@ func (r *RAG) RemoveFile(filename string) error {
|
||||
var (
|
||||
queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
|
||||
importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
|
||||
stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"}
|
||||
)
|
||||
|
||||
func (r *RAG) RefineQuery(query string) string {
|
||||
@@ -564,7 +657,20 @@ func (r *RAG) RefineQuery(query string) string {
|
||||
query = strings.ToLower(query)
|
||||
words := strings.Fields(query)
|
||||
if len(words) >= 3 {
|
||||
// Detect phrases and protect words that are part of phrases
|
||||
phrases := detectPhrases(query)
|
||||
protectedWords := make(map[string]bool)
|
||||
for _, phrase := range phrases {
|
||||
for _, word := range strings.Fields(phrase) {
|
||||
protectedWords[word] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Remove stop words that are not protected
|
||||
for _, stopWord := range stopWords {
|
||||
if protectedWords[stopWord] {
|
||||
continue
|
||||
}
|
||||
wordPattern := `\b` + stopWord + `\b`
|
||||
re := regexp.MustCompile(wordPattern)
|
||||
query = re.ReplaceAllString(query, "")
|
||||
@@ -673,6 +779,45 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
|
||||
if !strings.HasSuffix(query, " summary") {
|
||||
variations = append(variations, query+" summary")
|
||||
}
|
||||
|
||||
// Add phrase-quoted variations for better FTS5 matching
|
||||
phrases := detectPhrases(query)
|
||||
if len(phrases) > 0 {
|
||||
// Sort phrases by length descending to prioritize longer phrases
|
||||
sort.Slice(phrases, func(i, j int) bool {
|
||||
return len(phrases[i]) > len(phrases[j])
|
||||
})
|
||||
|
||||
// Create a version with all phrases quoted
|
||||
quotedQuery := query
|
||||
for _, phrase := range phrases {
|
||||
// Only quote if not already quoted
|
||||
quotedPhrase := "\"" + phrase + "\""
|
||||
if !strings.Contains(strings.ToLower(quotedQuery), strings.ToLower(quotedPhrase)) {
|
||||
// Case-insensitive replacement of phrase with quoted version
|
||||
re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`)
|
||||
quotedQuery = re.ReplaceAllString(quotedQuery, quotedPhrase)
|
||||
}
|
||||
}
|
||||
if quotedQuery != query {
|
||||
variations = append(variations, quotedQuery)
|
||||
}
|
||||
|
||||
// Also add individual phrase variations for short queries
|
||||
if len(phrases) <= 3 {
|
||||
for _, phrase := range phrases {
|
||||
// Create a focused query with just this phrase quoted
|
||||
// Keep original context but emphasize this phrase
|
||||
quotedPhrase := "\"" + phrase + "\""
|
||||
re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`)
|
||||
focusedQuery := re.ReplaceAllString(query, quotedPhrase)
|
||||
if focusedQuery != query && focusedQuery != quotedQuery {
|
||||
variations = append(variations, focusedQuery)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return variations
|
||||
}
|
||||
|
||||
@@ -704,6 +849,22 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
||||
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
|
||||
score += 3
|
||||
}
|
||||
|
||||
// Cross-chunk adjacency bonus: if this chunk has adjacent siblings in results,
|
||||
// boost score to promote narrative continuity
|
||||
adjacentCount := 0
|
||||
for _, other := range results {
|
||||
if other.Slug == row.Slug {
|
||||
continue
|
||||
}
|
||||
if areSlugsAdjacent(row.Slug, other.Slug) {
|
||||
adjacentCount++
|
||||
}
|
||||
}
|
||||
if adjacentCount > 0 {
|
||||
// Bonus per adjacent chunk, but diminishing returns
|
||||
score += float32(adjacentCount) * 4
|
||||
}
|
||||
distance := row.Distance - score/100
|
||||
scored = append(scored, scoredResult{row: row, distance: distance})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user