Enha: rag tuning and tests
This commit is contained in:
136
rag/rag.go
136
rag/rag.go
@@ -74,6 +74,22 @@ func detectPhrases(query string) []string {
|
||||
return phrases
|
||||
}
|
||||
|
||||
// countPhraseMatches returns the number of query phrases found in text
|
||||
func countPhraseMatches(text, query string) int {
|
||||
phrases := detectPhrases(query)
|
||||
if len(phrases) == 0 {
|
||||
return 0
|
||||
}
|
||||
textLower := strings.ToLower(text)
|
||||
count := 0
|
||||
for _, phrase := range phrases {
|
||||
if strings.Contains(textLower, phrase) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// parseSlugIndices extracts batch and chunk indices from a slug
|
||||
// slug format: filename_batch_chunk (e.g., "kjv_bible.epub_1786_0")
|
||||
func parseSlugIndices(slug string) (batch, chunk int, ok bool) {
|
||||
@@ -120,6 +136,9 @@ func areSlugsAdjacent(slug1, slug2 string) bool {
|
||||
|
||||
// Check if they're in sequential batches and chunk indices suggest continuity
|
||||
// This is heuristic but useful for cross-batch adjacency
|
||||
if (batch1 == batch2+1 && chunk1 == 0) || (batch2 == batch1+1 && chunk2 == 0) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -654,6 +673,10 @@ func (r *RAG) RefineQuery(query string) string {
|
||||
if len(query) <= 3 {
|
||||
return original
|
||||
}
|
||||
// If query already contains double quotes, assume it's a phrase query and skip refinement
|
||||
if strings.Contains(query, "\"") {
|
||||
return original
|
||||
}
|
||||
query = strings.ToLower(query)
|
||||
words := strings.Fields(query)
|
||||
if len(words) >= 3 {
|
||||
@@ -799,12 +822,13 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
|
||||
quotedQuery = re.ReplaceAllString(quotedQuery, quotedPhrase)
|
||||
}
|
||||
}
|
||||
if quotedQuery != query {
|
||||
variations = append(variations, quotedQuery)
|
||||
}
|
||||
// Disabled malformed quoted query for now
|
||||
// if quotedQuery != query {
|
||||
// variations = append(variations, quotedQuery)
|
||||
// }
|
||||
|
||||
// Also add individual phrase variations for short queries
|
||||
if len(phrases) <= 3 {
|
||||
if len(phrases) <= 5 {
|
||||
for _, phrase := range phrases {
|
||||
// Create a focused query with just this phrase quoted
|
||||
// Keep original context but emphasize this phrase
|
||||
@@ -814,6 +838,8 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
|
||||
if focusedQuery != query && focusedQuery != quotedQuery {
|
||||
variations = append(variations, focusedQuery)
|
||||
}
|
||||
// Add the phrase alone (quoted) as a separate variation
|
||||
variations = append(variations, quotedPhrase)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -822,9 +848,11 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
|
||||
}
|
||||
|
||||
func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow {
|
||||
phraseCount := len(detectPhrases(query))
|
||||
type scoredResult struct {
|
||||
row models.VectorRow
|
||||
distance float32
|
||||
row models.VectorRow
|
||||
distance float32
|
||||
phraseMatches int
|
||||
}
|
||||
scored := make([]scoredResult, 0, len(results))
|
||||
for i := range results {
|
||||
@@ -850,6 +878,14 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
||||
score += 3
|
||||
}
|
||||
|
||||
// Phrase match bonus: extra points for containing detected phrases
|
||||
phraseMatches := countPhraseMatches(row.RawText, query)
|
||||
if phraseMatches > 0 {
|
||||
// Significant bonus per phrase to prioritize exact phrase matches
|
||||
r.logger.Debug("phrase match bonus", "slug", row.Slug, "phraseMatches", phraseMatches, "score", score)
|
||||
score += float32(phraseMatches) * 100
|
||||
}
|
||||
|
||||
// Cross-chunk adjacency bonus: if this chunk has adjacent siblings in results,
|
||||
// boost score to promote narrative continuity
|
||||
adjacentCount := 0
|
||||
@@ -866,17 +902,27 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
||||
score += float32(adjacentCount) * 4
|
||||
}
|
||||
distance := row.Distance - score/100
|
||||
scored = append(scored, scoredResult{row: row, distance: distance})
|
||||
scored = append(scored, scoredResult{row: row, distance: distance, phraseMatches: phraseMatches})
|
||||
}
|
||||
sort.Slice(scored, func(i, j int) bool {
|
||||
return scored[i].distance < scored[j].distance
|
||||
})
|
||||
unique := make([]models.VectorRow, 0)
|
||||
seen := make(map[string]bool)
|
||||
maxPerFile := 2
|
||||
if phraseCount > 0 {
|
||||
maxPerFile = 10
|
||||
}
|
||||
fileCounts := make(map[string]int)
|
||||
for i := range scored {
|
||||
if !seen[scored[i].row.Slug] {
|
||||
if fileCounts[scored[i].row.FileName] >= 2 {
|
||||
// Allow phrase-matching chunks to bypass per-file limit (up to +5 extra)
|
||||
allowed := fileCounts[scored[i].row.FileName] < maxPerFile
|
||||
if !allowed && scored[i].phraseMatches > 0 {
|
||||
// If chunk has phrase matches, allow extra slots (up to maxPerFile + 5)
|
||||
allowed = fileCounts[scored[i].row.FileName] < maxPerFile+5
|
||||
}
|
||||
if !allowed {
|
||||
continue
|
||||
}
|
||||
seen[scored[i].row.Slug] = true
|
||||
@@ -884,8 +930,8 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
||||
unique = append(unique, scored[i].row)
|
||||
}
|
||||
}
|
||||
if len(unique) > 10 {
|
||||
unique = unique[:10]
|
||||
if len(unique) > 30 {
|
||||
unique = unique[:30]
|
||||
}
|
||||
return unique
|
||||
}
|
||||
@@ -954,6 +1000,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||
r.resetIdleTimer()
|
||||
refined := r.RefineQuery(query)
|
||||
variations := r.GenerateQueryVariations(refined)
|
||||
r.logger.Debug("query variations", "original", query, "refined", refined, "variations", variations)
|
||||
|
||||
// Collect embedding search results from all variations
|
||||
var embResults []models.VectorRow
|
||||
@@ -985,17 +1032,35 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||
return embResults[i].Distance < embResults[j].Distance
|
||||
})
|
||||
|
||||
// Perform keyword search
|
||||
kwResults, err := r.searchKeyword(refined, limit*2)
|
||||
if err != nil {
|
||||
r.logger.Warn("keyword search failed, using only embeddings", "error", err)
|
||||
kwResults = nil
|
||||
// Perform keyword search on all variations
|
||||
var kwResults []models.VectorRow
|
||||
seenKw := make(map[string]bool)
|
||||
for _, q := range variations {
|
||||
results, err := r.searchKeyword(q, limit)
|
||||
if err != nil {
|
||||
r.logger.Debug("keyword search failed for variation", "error", err, "query", q)
|
||||
continue
|
||||
}
|
||||
for _, row := range results {
|
||||
if !seenKw[row.Slug] {
|
||||
seenKw[row.Slug] = true
|
||||
kwResults = append(kwResults, row)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sort keyword results by distance (already sorted by BM25 score)
|
||||
// kwResults already sorted by distance (lower is better)
|
||||
// Sort keyword results by distance (lower is better)
|
||||
sort.Slice(kwResults, func(i, j int) bool {
|
||||
return kwResults[i].Distance < kwResults[j].Distance
|
||||
})
|
||||
|
||||
// Combine using Reciprocal Rank Fusion (RRF)
|
||||
const rrfK = 60
|
||||
// Use smaller K for phrase-heavy queries to give more weight to top ranks
|
||||
phraseCount := len(detectPhrases(query))
|
||||
rrfK := 60.0
|
||||
if phraseCount > 0 {
|
||||
rrfK = 30.0
|
||||
}
|
||||
r.logger.Debug("RRF parameters", "phraseCount", phraseCount, "rrfK", rrfK, "query", query)
|
||||
type scoredRow struct {
|
||||
row models.VectorRow
|
||||
score float64
|
||||
@@ -1005,11 +1070,22 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||
for rank, row := range embResults {
|
||||
score := 1.0 / (float64(rank) + rrfK)
|
||||
scoreMap[row.Slug] += score
|
||||
if row.Slug == "kjv_bible.epub_1786_0" {
|
||||
r.logger.Debug("target chunk embedding rank", "rank", rank, "score", score)
|
||||
}
|
||||
}
|
||||
// Add keyword results
|
||||
// Add keyword results with weight boost when phrases are present
|
||||
kwWeight := 1.0
|
||||
if phraseCount > 0 {
|
||||
kwWeight = 100.0
|
||||
}
|
||||
r.logger.Debug("keyword weight", "kwWeight", kwWeight, "phraseCount", phraseCount)
|
||||
for rank, row := range kwResults {
|
||||
score := 1.0 / (float64(rank) + rrfK)
|
||||
score := kwWeight * (1.0 / (float64(rank) + rrfK))
|
||||
scoreMap[row.Slug] += score
|
||||
if row.Slug == "kjv_bible.epub_1786_0" {
|
||||
r.logger.Debug("target chunk keyword rank", "rank", rank, "score", score, "kwWeight", kwWeight, "rrfK", rrfK)
|
||||
}
|
||||
// Ensure row exists in combined results
|
||||
if _, exists := seen[row.Slug]; !exists {
|
||||
embResults = append(embResults, row)
|
||||
@@ -1021,6 +1097,18 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||
score := scoreMap[row.Slug]
|
||||
scoredRows = append(scoredRows, scoredRow{row: row, score: score})
|
||||
}
|
||||
// Debug: log scores for target chunk and top chunks
|
||||
if strings.Contains(strings.ToLower(query), "bald") || strings.Contains(strings.ToLower(query), "she bears") {
|
||||
for _, sr := range scoredRows {
|
||||
if sr.row.Slug == "kjv_bible.epub_1786_0" {
|
||||
r.logger.Debug("target chunk score", "slug", sr.row.Slug, "score", sr.score, "distance", sr.row.Distance)
|
||||
}
|
||||
}
|
||||
// Log top 5 scores
|
||||
for i := 0; i < len(scoredRows) && i < 5; i++ {
|
||||
r.logger.Debug("top scored row", "rank", i+1, "slug", scoredRows[i].row.Slug, "score", scoredRows[i].score, "distance", scoredRows[i].row.Distance)
|
||||
}
|
||||
}
|
||||
// Sort by descending RRF score
|
||||
sort.Slice(scoredRows, func(i, j int) bool {
|
||||
return scoredRows[i].score > scoredRows[j].score
|
||||
@@ -1099,3 +1187,11 @@ func (r *RAG) Destroy() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetEmbedderForTesting replaces the internal embedder with a mock.
|
||||
// This function is only available when compiling with the "test" build tag.
|
||||
func (r *RAG) SetEmbedderForTesting(e Embedder) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.embedder = e
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user