Enhance (rag): semantic hybrid search
This commit is contained in:
167
rag/rag.go
167
rag/rag.go
@@ -12,6 +12,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -27,8 +28,101 @@ var (
|
|||||||
FinishedRAGStatus = "finished loading RAG file; press x to exit"
|
FinishedRAGStatus = "finished loading RAG file; press x to exit"
|
||||||
LoadedFileRAGStatus = "loaded file"
|
LoadedFileRAGStatus = "loaded file"
|
||||||
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
||||||
|
|
||||||
|
// stopWords are common words that can be removed from queries when not part of phrases
|
||||||
|
stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right", "about", "like", "such", "than", "then", "also", "too"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isStopWord checks if a word is in the stop words list
|
||||||
|
func isStopWord(word string) bool {
|
||||||
|
for _, stop := range stopWords {
|
||||||
|
if strings.EqualFold(word, stop) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectPhrases returns multi-word phrases from a query that should be treated as units
|
||||||
|
func detectPhrases(query string) []string {
|
||||||
|
words := strings.Fields(strings.ToLower(query))
|
||||||
|
var phrases []string
|
||||||
|
|
||||||
|
for i := 0; i < len(words)-1; i++ {
|
||||||
|
word1 := strings.Trim(words[i], ".,!?;:'\"()[]{}")
|
||||||
|
word2 := strings.Trim(words[i+1], ".,!?;:'\"()[]{}")
|
||||||
|
|
||||||
|
// Skip if either word is a stop word or too short
|
||||||
|
if isStopWord(word1) || isStopWord(word2) || len(word1) < 2 || len(word2) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this pair appears to be a meaningful phrase
|
||||||
|
// Simple heuristic: consecutive non-stop words of reasonable length
|
||||||
|
phrase := word1 + " " + word2
|
||||||
|
phrases = append(phrases, phrase)
|
||||||
|
|
||||||
|
// Optionally check for 3-word phrases
|
||||||
|
if i < len(words)-2 {
|
||||||
|
word3 := strings.Trim(words[i+2], ".,!?;:'\"()[]{}")
|
||||||
|
if !isStopWord(word3) && len(word3) >= 2 {
|
||||||
|
phrases = append(phrases, word1+" "+word2+" "+word3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return phrases
|
||||||
|
}
|
||||||
|
|
||||||
|
// slugIndexRe matches the trailing "_<batch>_<chunk>" suffix of a slug.
// Compiled once at package scope so parseSlugIndices does not pay a
// regexp compilation on every call (it runs per result row).
var slugIndexRe = regexp.MustCompile(`_(\d+)_(\d+)$`)

// parseSlugIndices extracts batch and chunk indices from a slug.
// slug format: filename_batch_chunk (e.g., "kjv_bible.epub_1786_0").
// ok is false when the slug does not end with two underscore-separated
// decimal numbers.
func parseSlugIndices(slug string) (batch, chunk int, ok bool) {
	// FindStringSubmatch returns nil on no match; len(nil) == 0, so a
	// single length check covers both the nil and malformed cases.
	matches := slugIndexRe.FindStringSubmatch(slug)
	if len(matches) != 3 {
		return 0, 0, false
	}
	b, err1 := strconv.Atoi(matches[1])
	c, err2 := strconv.Atoi(matches[2])
	if err1 != nil || err2 != nil {
		// Unreachable in practice (the regexp only matches digits) but
		// kept as a guard against overflow on absurdly long numbers.
		return 0, 0, false
	}
	return b, c, true
}
|
||||||
|
|
||||||
|
// areSlugsAdjacent returns true if two slugs are from the same file and have sequential indices
|
||||||
|
func areSlugsAdjacent(slug1, slug2 string) bool {
|
||||||
|
// Extract filename prefix (everything before the last underscore sequence)
|
||||||
|
parts1 := strings.Split(slug1, "_")
|
||||||
|
parts2 := strings.Split(slug2, "_")
|
||||||
|
if len(parts1) < 3 || len(parts2) < 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare filename prefixes (all parts except last two)
|
||||||
|
prefix1 := strings.Join(parts1[:len(parts1)-2], "_")
|
||||||
|
prefix2 := strings.Join(parts2[:len(parts2)-2], "_")
|
||||||
|
if prefix1 != prefix2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
batch1, chunk1, ok1 := parseSlugIndices(slug1)
|
||||||
|
batch2, chunk2, ok2 := parseSlugIndices(slug2)
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if they're in same batch and chunks are sequential
|
||||||
|
if batch1 == batch2 && (chunk1 == chunk2+1 || chunk2 == chunk1+1) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if they're in sequential batches and chunk indices suggest continuity
|
||||||
|
// This is heuristic but useful for cross-batch adjacency
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type RAG struct {
|
type RAG struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
store storage.FullRepo
|
store storage.FullRepo
|
||||||
@@ -155,8 +249,8 @@ func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeFTSQuery(query string) string {
|
func sanitizeFTSQuery(query string) string {
|
||||||
// Remove double quotes and other problematic characters for FTS5
|
// Keep double quotes for FTS5 phrase matching
|
||||||
// query = strings.ReplaceAll(query, "\"", " ")
|
// Remove other problematic characters
|
||||||
query = strings.ReplaceAll(query, "'", " ")
|
query = strings.ReplaceAll(query, "'", " ")
|
||||||
query = strings.ReplaceAll(query, ";", " ")
|
query = strings.ReplaceAll(query, ";", " ")
|
||||||
query = strings.ReplaceAll(query, "\\", " ")
|
query = strings.ReplaceAll(query, "\\", " ")
|
||||||
@@ -549,7 +643,6 @@ func (r *RAG) RemoveFile(filename string) error {
|
|||||||
var (
|
var (
|
||||||
queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
|
queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
|
||||||
importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
|
importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
|
||||||
stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *RAG) RefineQuery(query string) string {
|
func (r *RAG) RefineQuery(query string) string {
|
||||||
@@ -564,7 +657,20 @@ func (r *RAG) RefineQuery(query string) string {
|
|||||||
query = strings.ToLower(query)
|
query = strings.ToLower(query)
|
||||||
words := strings.Fields(query)
|
words := strings.Fields(query)
|
||||||
if len(words) >= 3 {
|
if len(words) >= 3 {
|
||||||
|
// Detect phrases and protect words that are part of phrases
|
||||||
|
phrases := detectPhrases(query)
|
||||||
|
protectedWords := make(map[string]bool)
|
||||||
|
for _, phrase := range phrases {
|
||||||
|
for _, word := range strings.Fields(phrase) {
|
||||||
|
protectedWords[word] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove stop words that are not protected
|
||||||
for _, stopWord := range stopWords {
|
for _, stopWord := range stopWords {
|
||||||
|
if protectedWords[stopWord] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
wordPattern := `\b` + stopWord + `\b`
|
wordPattern := `\b` + stopWord + `\b`
|
||||||
re := regexp.MustCompile(wordPattern)
|
re := regexp.MustCompile(wordPattern)
|
||||||
query = re.ReplaceAllString(query, "")
|
query = re.ReplaceAllString(query, "")
|
||||||
@@ -673,6 +779,45 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
|
|||||||
if !strings.HasSuffix(query, " summary") {
|
if !strings.HasSuffix(query, " summary") {
|
||||||
variations = append(variations, query+" summary")
|
variations = append(variations, query+" summary")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add phrase-quoted variations for better FTS5 matching
|
||||||
|
phrases := detectPhrases(query)
|
||||||
|
if len(phrases) > 0 {
|
||||||
|
// Sort phrases by length descending to prioritize longer phrases
|
||||||
|
sort.Slice(phrases, func(i, j int) bool {
|
||||||
|
return len(phrases[i]) > len(phrases[j])
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a version with all phrases quoted
|
||||||
|
quotedQuery := query
|
||||||
|
for _, phrase := range phrases {
|
||||||
|
// Only quote if not already quoted
|
||||||
|
quotedPhrase := "\"" + phrase + "\""
|
||||||
|
if !strings.Contains(strings.ToLower(quotedQuery), strings.ToLower(quotedPhrase)) {
|
||||||
|
// Case-insensitive replacement of phrase with quoted version
|
||||||
|
re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`)
|
||||||
|
quotedQuery = re.ReplaceAllString(quotedQuery, quotedPhrase)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if quotedQuery != query {
|
||||||
|
variations = append(variations, quotedQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also add individual phrase variations for short queries
|
||||||
|
if len(phrases) <= 3 {
|
||||||
|
for _, phrase := range phrases {
|
||||||
|
// Create a focused query with just this phrase quoted
|
||||||
|
// Keep original context but emphasize this phrase
|
||||||
|
quotedPhrase := "\"" + phrase + "\""
|
||||||
|
re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`)
|
||||||
|
focusedQuery := re.ReplaceAllString(query, quotedPhrase)
|
||||||
|
if focusedQuery != query && focusedQuery != quotedQuery {
|
||||||
|
variations = append(variations, focusedQuery)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return variations
|
return variations
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -704,6 +849,22 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
|||||||
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
|
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
|
||||||
score += 3
|
score += 3
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cross-chunk adjacency bonus: if this chunk has adjacent siblings in results,
|
||||||
|
// boost score to promote narrative continuity
|
||||||
|
adjacentCount := 0
|
||||||
|
for _, other := range results {
|
||||||
|
if other.Slug == row.Slug {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if areSlugsAdjacent(row.Slug, other.Slug) {
|
||||||
|
adjacentCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if adjacentCount > 0 {
|
||||||
|
// Bonus per adjacent chunk, but diminishing returns
|
||||||
|
score += float32(adjacentCount) * 4
|
||||||
|
}
|
||||||
distance := row.Distance - score/100
|
distance := row.Distance - score/100
|
||||||
scored = append(scored, scoredResult{row: row, distance: distance})
|
scored = append(scored, scoredResult{row: row, distance: distance})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user