Enhance: rag tuning and tests

This commit is contained in:
Grail Finder
2026-03-08 16:12:32 +03:00
parent e74ff8c03f
commit a1b5f9cdc5
5 changed files with 814 additions and 25 deletions

View File

@@ -74,6 +74,22 @@ func detectPhrases(query string) []string {
return phrases
}
// countPhraseMatches reports how many of the phrases detected in query
// appear in text. Matching is case-insensitive on the text side; phrases
// returned by detectPhrases are assumed to already be lowercased.
func countPhraseMatches(text, query string) int {
	phraseList := detectPhrases(query)
	if len(phraseList) == 0 {
		// No phrases in the query means nothing can match.
		return 0
	}
	haystack := strings.ToLower(text)
	matched := 0
	for _, p := range phraseList {
		if strings.Contains(haystack, p) {
			matched++
		}
	}
	return matched
}
// parseSlugIndices extracts batch and chunk indices from a slug
// slug format: filename_batch_chunk (e.g., "kjv_bible.epub_1786_0")
func parseSlugIndices(slug string) (batch, chunk int, ok bool) {
@@ -120,6 +136,9 @@ func areSlugsAdjacent(slug1, slug2 string) bool {
// Check if they're in sequential batches and chunk indices suggest continuity
// This is heuristic but useful for cross-batch adjacency
if (batch1 == batch2+1 && chunk1 == 0) || (batch2 == batch1+1 && chunk2 == 0) {
return true
}
return false
}
@@ -654,6 +673,10 @@ func (r *RAG) RefineQuery(query string) string {
if len(query) <= 3 {
return original
}
// If query already contains double quotes, assume it's a phrase query and skip refinement
if strings.Contains(query, "\"") {
return original
}
query = strings.ToLower(query)
words := strings.Fields(query)
if len(words) >= 3 {
@@ -799,12 +822,13 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
quotedQuery = re.ReplaceAllString(quotedQuery, quotedPhrase)
}
}
if quotedQuery != query {
variations = append(variations, quotedQuery)
}
// Disabled malformed quoted query for now
// if quotedQuery != query {
// variations = append(variations, quotedQuery)
// }
// Also add individual phrase variations for short queries
if len(phrases) <= 3 {
if len(phrases) <= 5 {
for _, phrase := range phrases {
// Create a focused query with just this phrase quoted
// Keep original context but emphasize this phrase
@@ -814,6 +838,8 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
if focusedQuery != query && focusedQuery != quotedQuery {
variations = append(variations, focusedQuery)
}
// Add the phrase alone (quoted) as a separate variation
variations = append(variations, quotedPhrase)
}
}
}
@@ -822,9 +848,11 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
}
func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow {
phraseCount := len(detectPhrases(query))
type scoredResult struct {
row models.VectorRow
distance float32
row models.VectorRow
distance float32
phraseMatches int
}
scored := make([]scoredResult, 0, len(results))
for i := range results {
@@ -850,6 +878,14 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
score += 3
}
// Phrase match bonus: extra points for containing detected phrases
phraseMatches := countPhraseMatches(row.RawText, query)
if phraseMatches > 0 {
// Significant bonus per phrase to prioritize exact phrase matches
r.logger.Debug("phrase match bonus", "slug", row.Slug, "phraseMatches", phraseMatches, "score", score)
score += float32(phraseMatches) * 100
}
// Cross-chunk adjacency bonus: if this chunk has adjacent siblings in results,
// boost score to promote narrative continuity
adjacentCount := 0
@@ -866,17 +902,27 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
score += float32(adjacentCount) * 4
}
distance := row.Distance - score/100
scored = append(scored, scoredResult{row: row, distance: distance})
scored = append(scored, scoredResult{row: row, distance: distance, phraseMatches: phraseMatches})
}
sort.Slice(scored, func(i, j int) bool {
return scored[i].distance < scored[j].distance
})
unique := make([]models.VectorRow, 0)
seen := make(map[string]bool)
maxPerFile := 2
if phraseCount > 0 {
maxPerFile = 10
}
fileCounts := make(map[string]int)
for i := range scored {
if !seen[scored[i].row.Slug] {
if fileCounts[scored[i].row.FileName] >= 2 {
// Allow phrase-matching chunks to bypass per-file limit (up to +5 extra)
allowed := fileCounts[scored[i].row.FileName] < maxPerFile
if !allowed && scored[i].phraseMatches > 0 {
// If chunk has phrase matches, allow extra slots (up to maxPerFile + 5)
allowed = fileCounts[scored[i].row.FileName] < maxPerFile+5
}
if !allowed {
continue
}
seen[scored[i].row.Slug] = true
@@ -884,8 +930,8 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
unique = append(unique, scored[i].row)
}
}
if len(unique) > 10 {
unique = unique[:10]
if len(unique) > 30 {
unique = unique[:30]
}
return unique
}
@@ -954,6 +1000,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
r.resetIdleTimer()
refined := r.RefineQuery(query)
variations := r.GenerateQueryVariations(refined)
r.logger.Debug("query variations", "original", query, "refined", refined, "variations", variations)
// Collect embedding search results from all variations
var embResults []models.VectorRow
@@ -985,17 +1032,35 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
return embResults[i].Distance < embResults[j].Distance
})
// Perform keyword search
kwResults, err := r.searchKeyword(refined, limit*2)
if err != nil {
r.logger.Warn("keyword search failed, using only embeddings", "error", err)
kwResults = nil
// Perform keyword search on all variations
var kwResults []models.VectorRow
seenKw := make(map[string]bool)
for _, q := range variations {
results, err := r.searchKeyword(q, limit)
if err != nil {
r.logger.Debug("keyword search failed for variation", "error", err, "query", q)
continue
}
for _, row := range results {
if !seenKw[row.Slug] {
seenKw[row.Slug] = true
kwResults = append(kwResults, row)
}
}
}
// Sort keyword results by distance (already sorted by BM25 score)
// kwResults already sorted by distance (lower is better)
// Sort keyword results by distance (lower is better)
sort.Slice(kwResults, func(i, j int) bool {
return kwResults[i].Distance < kwResults[j].Distance
})
// Combine using Reciprocal Rank Fusion (RRF)
const rrfK = 60
// Use smaller K for phrase-heavy queries to give more weight to top ranks
phraseCount := len(detectPhrases(query))
rrfK := 60.0
if phraseCount > 0 {
rrfK = 30.0
}
r.logger.Debug("RRF parameters", "phraseCount", phraseCount, "rrfK", rrfK, "query", query)
type scoredRow struct {
row models.VectorRow
score float64
@@ -1005,11 +1070,22 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
for rank, row := range embResults {
score := 1.0 / (float64(rank) + rrfK)
scoreMap[row.Slug] += score
if row.Slug == "kjv_bible.epub_1786_0" {
r.logger.Debug("target chunk embedding rank", "rank", rank, "score", score)
}
}
// Add keyword results
// Add keyword results with weight boost when phrases are present
kwWeight := 1.0
if phraseCount > 0 {
kwWeight = 100.0
}
r.logger.Debug("keyword weight", "kwWeight", kwWeight, "phraseCount", phraseCount)
for rank, row := range kwResults {
score := 1.0 / (float64(rank) + rrfK)
score := kwWeight * (1.0 / (float64(rank) + rrfK))
scoreMap[row.Slug] += score
if row.Slug == "kjv_bible.epub_1786_0" {
r.logger.Debug("target chunk keyword rank", "rank", rank, "score", score, "kwWeight", kwWeight, "rrfK", rrfK)
}
// Ensure row exists in combined results
if _, exists := seen[row.Slug]; !exists {
embResults = append(embResults, row)
@@ -1021,6 +1097,18 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
score := scoreMap[row.Slug]
scoredRows = append(scoredRows, scoredRow{row: row, score: score})
}
// Debug: log scores for target chunk and top chunks
if strings.Contains(strings.ToLower(query), "bald") || strings.Contains(strings.ToLower(query), "she bears") {
for _, sr := range scoredRows {
if sr.row.Slug == "kjv_bible.epub_1786_0" {
r.logger.Debug("target chunk score", "slug", sr.row.Slug, "score", sr.score, "distance", sr.row.Distance)
}
}
// Log top 5 scores
for i := 0; i < len(scoredRows) && i < 5; i++ {
r.logger.Debug("top scored row", "rank", i+1, "slug", scoredRows[i].row.Slug, "score", scoredRows[i].score, "distance", scoredRows[i].row.Distance)
}
}
// Sort by descending RRF score
sort.Slice(scoredRows, func(i, j int) bool {
return scoredRows[i].score > scoredRows[j].score
@@ -1099,3 +1187,11 @@ func (r *RAG) Destroy() {
}
}
}
// SetEmbedderForTesting swaps the RAG's internal embedder for a mock
// implementation, taking the write lock so the swap is safe against
// concurrent use. Per the accompanying build note, this helper is only
// compiled in under the "test" build tag.
func (r *RAG) SetEmbedderForTesting(mock Embedder) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.embedder = mock
}