Files
gf-lt/rag/rag.go
2026-03-08 16:12:32 +03:00

1198 lines
35 KiB
Go

package rag
import (
"context"
"errors"
"fmt"
"gf-lt/config"
"gf-lt/models"
"gf-lt/storage"
"log/slog"
"path"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/neurosnap/sentences/english"
)
var (
	// LongJobStatusCh carries human-readable progress messages for long
	// running jobs (file loads) to the TUI. It is buffered so the parallel
	// embedding pipeline can post updates without waiting; sendStatusNonBlocking
	// drops messages when the buffer is full.
	LongJobStatusCh = make(chan string, 100)
	// FinishedRAGStatus is posted after every batch of a file has been stored.
	FinishedRAGStatus = "finished loading RAG file; press x to exit"
	// LoadedFileRAGStatus is posted once a file's text has been extracted.
	LoadedFileRAGStatus = "loaded file"
	// ErrRAGStatus is posted when embedding or writing vectors fails.
	ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
	// stopWords are common words that can be removed from queries when not part of phrases.
	stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right", "about", "like", "such", "than", "then", "also", "too"}
)
// isStopWord reports whether word matches any entry of stopWords, ignoring
// case (Unicode case folding via strings.EqualFold).
func isStopWord(word string) bool {
	for i := range stopWords {
		if strings.EqualFold(stopWords[i], word) {
			return true
		}
	}
	return false
}
// detectPhrases returns candidate multi-word phrases from query.
//
// Words are lower-cased and stripped of edge punctuation. Every adjacent pair
// of words in which both words are at least 2 bytes long and not stop words
// yields a two-word phrase. When the following word also qualifies, the
// three-word extension is appended right after it. Duplicates are kept, and
// nil is returned when nothing qualifies.
func detectPhrases(query string) []string {
	const edgePunct = ".,!?;:'\"()[]{}"
	tokens := strings.Fields(strings.ToLower(query))
	usable := func(w string) bool {
		return len(w) >= 2 && !isStopWord(w)
	}
	var found []string
	for i := 0; i+1 < len(tokens); i++ {
		first := strings.Trim(tokens[i], edgePunct)
		second := strings.Trim(tokens[i+1], edgePunct)
		if !usable(first) || !usable(second) {
			continue
		}
		found = append(found, first+" "+second)
		if i+2 >= len(tokens) {
			continue
		}
		if third := strings.Trim(tokens[i+2], edgePunct); usable(third) {
			found = append(found, first+" "+second+" "+third)
		}
	}
	return found
}
// countPhraseMatches counts how many of the phrases detected in query occur
// (as case-insensitive substrings) in text. Repeated phrases count each time
// they appear in the detected list.
func countPhraseMatches(text, query string) int {
	matches := 0
	if phrases := detectPhrases(query); len(phrases) > 0 {
		haystack := strings.ToLower(text)
		for _, p := range phrases {
			if strings.Contains(haystack, p) {
				matches++
			}
		}
	}
	return matches
}
// slugIndexRe matches the trailing "_<batch>_<chunk>" suffix of a slug. It is
// compiled once at package scope; parseSlugIndices is called for every pair
// of results during reranking.
var slugIndexRe = regexp.MustCompile(`_(\d+)_(\d+)$`)

// parseSlugIndices extracts the batch and chunk indices from a slug of the
// form "<filename>_<batch>_<chunk>" (e.g. "kjv_bible.epub_1786_0").
// ok is false when the suffix is missing or a number does not fit in an int.
func parseSlugIndices(slug string) (batch, chunk int, ok bool) {
	m := slugIndexRe.FindStringSubmatch(slug)
	if m == nil {
		return 0, 0, false
	}
	b, err := strconv.Atoi(m[1])
	if err != nil {
		return 0, 0, false
	}
	c, err := strconv.Atoi(m[2])
	if err != nil {
		return 0, 0, false
	}
	return b, c, true
}
// areSlugsAdjacent reports whether two slugs come from the same file (the
// part before the trailing "_<batch>_<chunk>") and look consecutive. They are
// consecutive when either
//   - they share a batch and their chunk indices differ by one, or
//   - one is chunk 0 of the batch right after the other's batch.
//
// The second rule is a heuristic: it does not check that the earlier slug is
// the last chunk of its batch.
func areSlugsAdjacent(slug1, slug2 string) bool {
	fields1 := strings.Split(slug1, "_")
	fields2 := strings.Split(slug2, "_")
	if len(fields1) < 3 || len(fields2) < 3 {
		return false
	}
	// The file prefix is every field except the trailing batch and chunk.
	if strings.Join(fields1[:len(fields1)-2], "_") != strings.Join(fields2[:len(fields2)-2], "_") {
		return false
	}
	b1, c1, ok := parseSlugIndices(slug1)
	if !ok {
		return false
	}
	b2, c2, ok := parseSlugIndices(slug2)
	if !ok {
		return false
	}
	switch {
	case b1 == b2:
		return c1 == c2+1 || c2 == c1+1
	case b1 == b2+1:
		return c1 == 0
	case b2 == b1+1:
		return c2 == 0
	}
	return false
}
// RAG ties together an embedder, the vector storage and the app config, and
// exposes loading, searching and cleanup of embedded documents.
type RAG struct {
// logger receives all diagnostic output.
logger *slog.Logger
// store is the repository passed to New; storage wraps it (NewVectorStorage).
store storage.FullRepo
cfg *config.Config
// embedder is the ONNX embedder, or the API embedder chosen by New as default/fallback.
embedder Embedder
storage *VectorStorage
// mu is write-locked by loads, removals, embedder teardown and replacement,
// and read-locked by ListLoaded, SynthesizeAnswer and Search.
mu sync.RWMutex
// idleMu serialises resetIdleTimer's access to idleTimer.
idleMu sync.Mutex
// fallbackMsg holds the ONNX construction error when New fell back to the API embedder.
fallbackMsg string
// idleTimer runs freeONNXMemory after idleTimeout without a resetIdleTimer call.
idleTimer *time.Timer
// idleTimeout is set to 30s by New.
idleTimeout time.Duration
}
// batchTask is one unit of embedding work handed to an embeddingWorker.
// LoadRAGWithContext fills paragraphs with the trimmed, non-empty chunks of
// one batch; batchIndex is the zero-based position of that batch in the file.
type batchTask struct {
	batchIndex   int
	totalBatches int
	filename     string
	paragraphs   []string
}
// batchResult is what an embeddingWorker sends back for one batchTask:
// embeddings[j] is the vector for paragraphs[j]. Both are nil for a batch
// that had no paragraphs.
type batchResult struct {
	batchIndex int
	filename   string
	paragraphs []string
	embeddings [][]float32
}
// sendStatusNonBlocking posts status on LongJobStatusCh if there is buffer
// space, and otherwise logs a warning and drops the message.
//
// NOTE(review): a send on a closed channel panics even inside a select, so
// despite the log text this does not protect against a closed channel.
func (r *RAG) sendStatusNonBlocking(status string) {
	select {
	case LongJobStatusCh <- status:
		return
	default:
	}
	r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", status)
}
// New builds a RAG bound to the given logger, repository and config.
//
// When both cfg.EmbedModelPath and cfg.EmbedTokenizerPath are set, it first
// tries a local ONNX embedder. If that fails, the error text is kept (see
// FallbackMessage) and the API embedder is used instead. Without those paths
// the API embedder is used directly. The returned error is currently always
// nil. Vector tables come from database migrations; nothing is created here.
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
	var (
		embedder    Embedder
		fallbackMsg string
	)
	switch {
	case cfg.EmbedModelPath == "" || cfg.EmbedTokenizerPath == "":
		embedder = NewAPIEmbedder(l, cfg)
		l.Info("using API embedder", "url", cfg.EmbedURL)
	default:
		onnxEmb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
		if err != nil {
			l.Error("failed to create ONNX embedder, falling back to API", "error", err)
			fallbackMsg = err.Error()
			embedder = NewAPIEmbedder(l, cfg)
			break
		}
		embedder = onnxEmb
		l.Info("using ONNX embedder", "model", cfg.EmbedModelPath, "dims", cfg.EmbedDims)
	}
	return &RAG{
		logger:      l,
		store:       s,
		cfg:         cfg,
		embedder:    embedder,
		storage:     NewVectorStorage(l, s),
		fallbackMsg: fallbackMsg,
		idleTimeout: 30 * time.Second,
	}, nil
}
// createChunks groups consecutive sentences into chunks of at most wordLimit
// words. Words are counted with strings.Fields, and each chunk's sentences are
// joined with single spaces.
//
// A chunk takes sentences while its total stays within wordLimit. A single
// sentence longer than wordLimit becomes a chunk on its own.
//
// When overlapWords > 0, consecutive chunks overlap: the next chunk starts
// with the trailing whole sentences of the previous chunk whose combined
// length is at most overlapWords. The start always advances by at least one
// sentence. An overlapWords >= wordLimit is reduced to wordLimit/2.
//
// Returns nil when there are no sentences.
func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
	if len(sentences) == 0 {
		return nil
	}
	if overlapWords >= wordLimit {
		overlapWords = wordLimit / 2
	}
	limit, overlap := int(wordLimit), int(overlapWords)
	// Count words once; the overlap walk would otherwise re-split sentences.
	counts := make([]int, len(sentences))
	for k, s := range sentences {
		counts[k] = len(strings.Fields(s))
	}
	var chunks []string
	for i := 0; i < len(sentences); {
		// Fill [i, j): always take the first sentence, then stop before the
		// sentence that would push the chunk past the limit.
		j, words := i, 0
		for j < len(sentences) && (j == i || words+counts[j] <= limit) {
			words += counts[j]
			j++
		}
		chunks = append(chunks, strings.Join(sentences[i:j], " "))
		if j >= len(sentences) {
			break
		}
		if overlap == 0 {
			i = j
			continue
		}
		// Walk back from the end of this chunk, re-using whole sentences while
		// they fit in the overlap budget, but never back to i itself (that
		// would repeat the same chunk forever).
		next, carried := j, 0
		for next-1 > i && carried+counts[next-1] <= overlap {
			carried += counts[next-1]
			next--
		}
		i = next
	}
	return chunks
}
// sanitizeFTSQuery prepares user text for a full-text query. Double quotes
// are kept so FTS5 phrase syntax still works. Single quotes, semicolons and
// backslashes become spaces, and the result is trimmed. An empty result
// becomes "*".
//
// NOTE(review): confirm that SearchKeyword accepts a bare "*" as match-all.
func sanitizeFTSQuery(query string) string {
	cleaner := strings.NewReplacer("'", " ", ";", " ", "\\", " ")
	cleaned := strings.TrimSpace(cleaner.Replace(query))
	if cleaned != "" {
		return cleaned
	}
	return "*" // match all
}
// LoadRAG embeds and stores the file at fpath; it is LoadRAGWithContext
// with a background context.
func (r *RAG) LoadRAG(fpath string) error {
	ctx := context.Background()
	return r.LoadRAGWithContext(ctx, fpath)
}
// LoadRAGWithContext extracts the text of fpath, splits it into overlapping
// sentence-based chunks and stores an embedding for every chunk.
//
// Chunks are grouped into batches of r.cfg.RAGBatchSize, clamped locally to
// [1, number of chunks]. Batches are embedded by a pool of up to NumCPU
// workers (one worker for the ONNX embedder). They are written to storage
// strictly in batch order. Progress is posted on LongJobStatusCh.
//
// The operation is bounded by a 30-minute timeout derived from ctx. The first
// embedding or storage error cancels all outstanding work and is returned.
// r.mu is write-locked for the whole load.
func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	fileText, err := ExtractText(fpath)
	if err != nil {
		return err
	}
	r.logger.Debug("rag: loaded file", "fp", fpath)
	r.sendStatusNonBlocking(LoadedFileRAGStatus)
	tokenizer, err := english.NewSentenceTokenizer(nil)
	if err != nil {
		return err
	}
	sentences := tokenizer.Tokenize(fileText)
	sents := make([]string, len(sentences))
	for i, s := range sentences {
		sents[i] = s.Text
	}
	paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords)
	if len(paragraphs) == 0 {
		return errors.New("no valid paragraphs found in file")
	}
	// Clamp into a local. Writing the clamped value back into r.cfg would
	// permanently shrink the batch size for every later load, and a
	// non-positive size would divide by zero below.
	batchSize := r.cfg.RAGBatchSize
	if batchSize <= 0 || batchSize > len(paragraphs) {
		batchSize = len(paragraphs)
	}
	totalBatches := (len(paragraphs) + batchSize - 1) / batchSize
	r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", batchSize)
	concurrency := runtime.NumCPU()
	if concurrency > totalBatches {
		concurrency = totalBatches
	}
	if concurrency < 1 {
		concurrency = 1
	}
	embedderType := "API"
	// ONNX inference is serialised by a mutex inside the embedder, so extra
	// workers would only queue behind it.
	if _, isONNX := r.embedder.(*ONNXEmbedder); isONNX {
		concurrency = 1
		embedderType = "ONNX"
	}
	r.logger.Debug("parallel embedding setup",
		"total_batches", totalBatches,
		"concurrency", concurrency,
		"embedder", embedderType,
		"batch_size", batchSize)
	ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
	defer cancel()
	// Every channel can hold one entry per batch, so neither the distributor
	// nor a worker ever blocks on a send, even after this function returns.
	taskCh := make(chan batchTask, totalBatches)
	resultCh := make(chan batchResult, totalBatches)
	errorCh := make(chan error, totalBatches)
	var wg sync.WaitGroup
	for w := 0; w < concurrency; w++ {
		wg.Add(1)
		go r.embeddingWorker(ctx, w, taskCh, resultCh, errorCh, &wg)
	}
	filename := path.Base(fpath)
	// Distributor: slice paragraphs into batches, drop blank chunks, and
	// close taskCh when done so the workers' range loops end.
	go func() {
		defer close(taskCh)
		r.logger.Debug("task distributor started", "total_batches", totalBatches)
		for i := 0; i < totalBatches; i++ {
			start := i * batchSize
			end := start + batchSize
			if end > len(paragraphs) {
				end = len(paragraphs)
			}
			nonEmptyBatch := make([]string, 0, end-start)
			for _, p := range paragraphs[start:end] {
				if trimmed := strings.TrimSpace(p); trimmed != "" {
					nonEmptyBatch = append(nonEmptyBatch, trimmed)
				}
			}
			task := batchTask{
				batchIndex:   i,
				paragraphs:   nonEmptyBatch,
				filename:     filename,
				totalBatches: totalBatches,
			}
			select {
			case taskCh <- task:
				r.logger.Debug("task distributor sent batch", "batch", i, "paragraphs", len(nonEmptyBatch))
			case <-ctx.Done():
				r.logger.Debug("task distributor cancelled", "batches_sent", i, "total_batches", totalBatches)
				return
			}
		}
		r.logger.Debug("task distributor finished", "batches_sent", totalBatches)
	}()
	// resultCh is closed only after every worker has returned. A worker puts
	// any error into errorCh before it returns, so once resultCh is closed no
	// further results or errors can arrive.
	go func() {
		wg.Wait()
		close(resultCh)
	}()
	failed := func(err error) error {
		r.logger.Error("embedding worker failed", "error", err)
		r.sendStatusNonBlocking(ErrRAGStatus)
		return fmt.Errorf("embedding failed: %w", err)
	}
	nextExpectedBatch := 0
	batchesProcessed := 0
	pending := make(map[int]batchResult)
	for open := true; open; {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case err := <-errorCh:
			return failed(err)
		case result, ok := <-resultCh:
			if !ok {
				open = false
				r.logger.Debug("result channel closed", "batches_processed", batchesProcessed, "total_batches", totalBatches)
				continue
			}
			pending[result.batchIndex] = result
			// Write every buffered batch that is now contiguous with what has
			// already been stored.
			for {
				res, ready := pending[nextExpectedBatch]
				if !ready {
					break
				}
				if err := r.writeBatchToStorage(ctx, res, filename); err != nil {
					return err
				}
				delete(pending, nextExpectedBatch)
				nextExpectedBatch++
				batchesProcessed++
				r.sendStatusNonBlocking(fmt.Sprintf("processed batch %d/%d", batchesProcessed, totalBatches))
			}
		}
	}
	// The select above picks randomly among ready cases, so a buffered error
	// or a cancellation may still be pending when the close is observed.
	select {
	case err := <-errorCh:
		return failed(err)
	default:
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	if nextExpectedBatch < totalBatches {
		r.logger.Error("missing batch results",
			"expected", totalBatches,
			"received", nextExpectedBatch,
			"missing", totalBatches-nextExpectedBatch)
		r.sendStatusNonBlocking(ErrRAGStatus)
		return fmt.Errorf("missing batch results: expected %d, got %d", totalBatches, nextExpectedBatch)
	}
	r.logger.Debug("all batches processed successfully", "total", totalBatches)
	r.logger.Debug("finished writing vectors", "batches", batchesProcessed)
	r.resetIdleTimer()
	r.sendStatusNonBlocking(FinishedRAGStatus)
	return nil
}
// embeddingWorker consumes tasks from taskCh until it is closed or ctx is
// cancelled.
//   - Empty batches are forwarded as results with nil embeddings.
//   - Other batches are embedded via embedWithRetry (3 attempts) and sent to
//     resultCh.
//   - On the first embedding error, the worker reports it on errorCh and
//     exits.
//   - A panic is recovered and reported on errorCh; it is dropped if the
//     channel is full.
//
// wg.Done runs on every exit path.
func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan batchTask, resultCh chan<- batchResult, errorCh chan<- error, wg *sync.WaitGroup) {
	defer wg.Done()
	r.logger.Debug("embedding worker started", "worker", workerID)
	// Registered after wg.Done, so it runs first and the panic error is
	// buffered before the WaitGroup is released.
	defer func() {
		rec := recover()
		if rec == nil {
			return
		}
		r.logger.Error("embedding worker panicked", "worker", workerID, "panic", rec)
		select {
		case errorCh <- fmt.Errorf("worker %d panicked: %v", workerID, rec):
		default:
			r.logger.Warn("error channel full, dropping panic error", "worker", workerID)
		}
	}()
	// deliver sends res unless ctx is cancelled first, in which case it logs
	// cancelledMsg and reports false.
	deliver := func(res batchResult, cancelledMsg string) bool {
		select {
		case resultCh <- res:
			return true
		case <-ctx.Done():
			r.logger.Debug(cancelledMsg, "worker", workerID)
			return false
		}
	}
	for task := range taskCh {
		if ctx.Err() != nil {
			r.logger.Debug("embedding worker cancelled", "worker", workerID)
			return
		}
		r.logger.Debug("worker processing batch", "worker", workerID, "batch", task.batchIndex, "paragraphs", len(task.paragraphs), "total_batches", task.totalBatches)
		if len(task.paragraphs) == 0 {
			empty := batchResult{batchIndex: task.batchIndex, filename: task.filename}
			if !deliver(empty, "embedding worker cancelled while sending empty batch") {
				return
			}
			r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex)
			continue
		}
		embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3)
		if err != nil {
			select {
			case errorCh <- fmt.Errorf("worker %d batch %d: %w", workerID, task.batchIndex, err):
			case <-ctx.Done():
				r.logger.Debug("embedding worker cancelled while sending error", "worker", workerID)
			}
			return
		}
		done := batchResult{
			batchIndex: task.batchIndex,
			embeddings: embeddings,
			paragraphs: task.paragraphs,
			filename:   task.filename,
		}
		if !deliver(done, "embedding worker cancelled while sending result") {
			return
		}
		r.logger.Debug("worker completed batch", "worker", workerID, "batch", task.batchIndex, "embeddings", len(embeddings))
	}
	r.logger.Debug("embedding worker finished", "worker", workerID)
}
// embedWithRetry embeds paragraphs with r.embedder.
//
// Failed calls are retried only when the embedder is an *APIEmbedder, whose
// errors are treated as transient network/timeout failures. Any other
// embedder fails on its first error. maxRetries is the total number of
// attempts; values below 1 are treated as 1.
//
// Before retry n (n = 1, 2, ...) it waits n² seconds, capped at 10s. It
// returns ctx.Err() if ctx is cancelled while waiting. A successful call
// whose result length differs from len(paragraphs) is an error and is not
// retried.
func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) {
	if maxRetries < 1 {
		// With zero attempts the final error would wrap a nil error.
		maxRetries = 1
	}
	_, retryable := r.embedder.(*APIEmbedder)
	var lastErr error
	attempts := 0
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			backoff := time.Duration(attempt*attempt) * time.Second
			if backoff > 10*time.Second {
				backoff = 10 * time.Second
			}
			select {
			case <-time.After(backoff):
			case <-ctx.Done():
				return nil, ctx.Err()
			}
			r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries)
		}
		attempts++
		embeddings, err := r.embedder.EmbedSlice(paragraphs)
		if err == nil {
			if len(embeddings) != len(paragraphs) {
				return nil, fmt.Errorf("embedding count mismatch: expected %d, got %d", len(paragraphs), len(embeddings))
			}
			return embeddings, nil
		}
		lastErr = err
		if !retryable {
			break
		}
	}
	// Report the attempts actually made; a non-API embedder stops after one.
	return nil, fmt.Errorf("embedding failed after %d attempts: %w", attempts, lastErr)
}
// writeBatchToStorage persists one embedded batch with a single WriteVectors
// call. Batches without embeddings are skipped.
//
// Each row's slug is "<filename>_<batchIndex+1>_<position in batch>", the
// format parseSlugIndices expects. A mismatch between embeddings and
// paragraphs, or a write failure, posts ErrRAGStatus and returns an error.
// ctx is checked once before writing.
func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filename string) error {
	if len(result.embeddings) == 0 {
		return nil
	}
	// Guard the index below: result.embeddings[j] must exist for every paragraph.
	if len(result.embeddings) != len(result.paragraphs) {
		r.sendStatusNonBlocking(ErrRAGStatus)
		return fmt.Errorf("batch %d: %d embeddings for %d paragraphs", result.batchIndex+1, len(result.embeddings), len(result.paragraphs))
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	vectors := make([]*models.VectorRow, 0, len(result.paragraphs))
	for j, text := range result.paragraphs {
		vectors = append(vectors, &models.VectorRow{
			Embeddings: result.embeddings[j],
			RawText:    text,
			Slug:       fmt.Sprintf("%s_%d_%d", filename, result.batchIndex+1, j),
			FileName:   filename,
		})
	}
	if err := r.storage.WriteVectors(vectors); err != nil {
		r.logger.Error("failed to write vectors batch to DB", "error", err, "batch", result.batchIndex+1, "size", len(vectors))
		r.sendStatusNonBlocking(ErrRAGStatus)
		return fmt.Errorf("failed to write vectors batch: %w", err)
	}
	r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs))
	return nil
}
// LineToVector restarts the idle timer and embeds a single line of text with
// the configured embedder.
func (r *RAG) LineToVector(line string) ([]float32, error) {
	r.resetIdleTimer()
	vec, err := r.embedder.Embed(line)
	return vec, err
}
// searchEmb restarts the idle timer and returns up to limit stored rows
// closest to emb's vector. emb must be non-nil.
func (r *RAG) searchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
	r.resetIdleTimer()
	vector := emb.Embedding
	return r.storage.SearchClosest(vector, limit)
}
// searchKeyword restarts the idle timer and runs a full-text search for the
// sanitised query, returning up to limit rows.
func (r *RAG) searchKeyword(query string, limit int) ([]models.VectorRow, error) {
	r.resetIdleTimer()
	return r.storage.SearchKeyword(sanitizeFTSQuery(query), limit)
}
// ListLoaded returns the names of the files with stored vectors, under the
// read lock.
func (r *RAG) ListLoaded() (files []string, err error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	files, err = r.storage.ListFiles()
	return files, err
}
// RemoveFile deletes every stored vector for filename under the write lock
// and restarts the idle timer.
func (r *RAG) RemoveFile(filename string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.resetIdleTimer()
	err := r.storage.RemoveEmbByFileName(filename)
	return err
}
var (
	// queryRefinementPattern matches (case-insensitively) preambles such as
	// "based on my vector db" or "from my files", which RefineQuery strips
	// from the original query.
	queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
	// importantKeywords keeps short words in extractImportantPhrases when a
	// word contains one of them.
	importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
)
// stopWordRe matches any entry of stopWords as a whole word. It is compiled
// once instead of compiling one regexp per stop word on every RefineQuery
// call.
var stopWordRe = func() *regexp.Regexp {
	alternatives := make([]string, len(stopWords))
	for i, w := range stopWords {
		alternatives[i] = regexp.QuoteMeta(w)
	}
	return regexp.MustCompile(`\b(?:` + strings.Join(alternatives, "|") + `)\b`)
}()

// RefineQuery simplifies a free-form query before search. It returns the
// original string whenever refinement would leave fewer than 5 bytes.
//
// Steps:
//  1. Queries of at most 3 bytes (after trimming), or containing a double
//     quote (an explicit phrase query), are returned unchanged.
//  2. The query is lower-cased, and for queries of 3+ words the stop words
//     are removed.
//  3. If the original matches queryRefinementPattern, that preamble is
//     stripped from the original and the result is returned.
//  4. Otherwise extractImportantPhrases keeps only the significant words.
func (r *RAG) RefineQuery(query string) string {
	original := query
	query = strings.TrimSpace(query)
	if len(query) <= 3 || strings.Contains(query, "\"") {
		return original
	}
	query = strings.ToLower(query)
	if len(strings.Fields(query)) >= 3 {
		// No phrase words need protecting: detectPhrases rejects every pair or
		// triple that contains a stop word.
		query = stopWordRe.ReplaceAllString(query, "")
		// Collapse the gaps left behind, so the length checks below measure
		// content rather than leftover whitespace.
		query = strings.Join(strings.Fields(query), " ")
	}
	query = strings.TrimSpace(query)
	if len(query) < 5 {
		return original
	}
	if queryRefinementPattern.MatchString(original) {
		cleaned := strings.TrimSpace(queryRefinementPattern.ReplaceAllString(original, ""))
		if len(cleaned) >= 5 {
			return cleaned
		}
	}
	query = r.extractImportantPhrases(query)
	if len(query) < 5 {
		return original
	}
	return query
}
// extractImportantPhrases keeps the words of query (edge punctuation
// trimmed) that are at least 3 bytes long or contain one of
// importantKeywords, joined by single spaces. If no word survives, query is
// returned unchanged.
func (r *RAG) extractImportantPhrases(query string) string {
	hasKeyword := func(w string) bool {
		lw := strings.ToLower(w)
		for _, kw := range importantKeywords {
			if strings.Contains(lw, kw) {
				return true
			}
		}
		return false
	}
	var kept []string
	for _, raw := range strings.Fields(query) {
		w := strings.Trim(raw, ".,!?;:'\"()[]{}")
		if len(w) >= 3 || hasKeyword(w) {
			kept = append(kept, w)
		}
	}
	if len(kept) == 0 {
		return query
	}
	return strings.Join(kept, " ")
}
// GenerateQueryVariations expands query into a list of distinct search
// strings, starting with query itself. Queries shorter than 5 bytes, or made
// only of whitespace, yield just the query.
//
// Variations are added in this order, each at most once:
//   - the query without words that overlap a loaded file name;
//   - the query without its last word, and without its first word (if >= 5 bytes);
//   - "<q> explanation", "what is <q>", "<q> details" and "<q> summary",
//     unless q already has that suffix or prefix;
//   - when at most 5 phrases are detected: for each phrase, the query with
//     that phrase quoted and the quoted phrase alone (for FTS5 phrase matching).
func (r *RAG) GenerateQueryVariations(query string) []string {
	variations := []string{query}
	if len(query) < 5 {
		return variations
	}
	parts := strings.Fields(query)
	if len(parts) == 0 {
		return variations
	}
	// Each variation costs an embedding plus a keyword search in Search, so
	// only distinct strings are kept.
	seen := map[string]bool{query: true}
	add := func(v string) {
		if !seen[v] {
			seen[v] = true
			variations = append(variations, v)
		}
	}
	// Drop words that overlap a loaded file name (e.g. the user naming the file
	// to search). This is best-effort: a listing failure only skips it.
	if filenames, err := r.storage.ListFiles(); err != nil {
		r.logger.Debug("listing files for query variations failed", "error", err)
	} else if len(filenames) > 0 {
		lowerFilenames := make([]string, len(filenames))
		for i, f := range filenames {
			lowerFilenames[i] = strings.ToLower(f)
		}
		filteredParts := make([]string, 0, len(parts))
		for _, part := range parts {
			partLower := strings.ToLower(part)
			overlaps := false
			for _, fn := range lowerFilenames {
				if strings.Contains(fn, partLower) || strings.Contains(partLower, fn) {
					overlaps = true
					break
				}
			}
			if !overlaps {
				filteredParts = append(filteredParts, part)
			}
		}
		if len(filteredParts) > 0 && len(filteredParts) != len(parts) {
			if filteredQuery := strings.Join(filteredParts, " "); len(filteredQuery) >= 5 {
				add(filteredQuery)
			}
		}
	}
	if len(parts) >= 2 {
		for _, trimmed := range []string{
			strings.Join(parts[:len(parts)-1], " "),
			strings.Join(parts[1:], " "),
		} {
			if len(trimmed) >= 5 {
				add(trimmed)
			}
		}
	}
	if !strings.HasSuffix(query, " explanation") {
		add(query + " explanation")
	}
	if !strings.HasPrefix(query, "what is ") {
		add("what is " + query)
	}
	if !strings.HasSuffix(query, " details") {
		add(query + " details")
	}
	if !strings.HasSuffix(query, " summary") {
		add(query + " summary")
	}
	phrases := detectPhrases(query)
	if len(phrases) == 0 {
		return variations
	}
	// Longest phrases first; stable so equal-length phrases keep detection order.
	sort.SliceStable(phrases, func(i, j int) bool {
		return len(phrases[i]) > len(phrases[j])
	})
	// One case-insensitive whole-phrase pattern per phrase, shared by both loops.
	patterns := make([]*regexp.Regexp, len(phrases))
	for i, phrase := range phrases {
		patterns[i] = regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`)
	}
	// quotedQuery quotes every phrase at once. It is not emitted as a
	// variation, because overlapping phrases nest quotes into malformed FTS5
	// syntax. Focused variations identical to it are still skipped.
	// Replacements are literal, so a "$" inside a phrase is not expanded as a
	// capture-group reference.
	quotedQuery := query
	for i, phrase := range phrases {
		quotedPhrase := "\"" + phrase + "\""
		if !strings.Contains(strings.ToLower(quotedQuery), strings.ToLower(quotedPhrase)) {
			quotedQuery = patterns[i].ReplaceAllLiteralString(quotedQuery, quotedPhrase)
		}
	}
	if len(phrases) <= 5 {
		for i, phrase := range phrases {
			quotedPhrase := "\"" + phrase + "\""
			focusedQuery := patterns[i].ReplaceAllLiteralString(query, quotedPhrase)
			if focusedQuery != query && focusedQuery != quotedQuery {
				add(focusedQuery)
			}
			add(quotedPhrase)
		}
	}
	return variations
}
// RerankResults re-scores search results with lexical heuristics and returns
// at most 30 of them, best first.
//
// Each row's vector distance is reduced by score/100, where score adds:
//   - 10 if the row text contains the whole (lower-cased) query;
//   - up to 5 for the fraction of query words (longer than 2 bytes) it contains;
//   - 3 for "chat" or "*conversation*" sources;
//   - 100 per detected query phrase found in the text;
//   - 4 per other result whose slug is adjacent (see areSlugsAdjacent).
//
// Rows are then de-duplicated by slug and capped per file. The cap is 2 per
// file, or 10 when the query contains phrases. Rows that matched a phrase get
// up to 5 extra slots.
func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow {
	phraseCount := len(detectPhrases(query))
	// Query-derived values are the same for every row; compute them once.
	queryLower := strings.ToLower(query)
	queryWords := strings.Fields(queryLower)
	type scoredResult struct {
		row           models.VectorRow
		distance      float32
		phraseMatches int
	}
	scored := make([]scoredResult, 0, len(results))
	for i := range results {
		row := results[i]
		rawTextLower := strings.ToLower(row.RawText)
		score := float32(0)
		if strings.Contains(rawTextLower, queryLower) {
			score += 10
		}
		matchCount := 0
		for _, word := range queryWords {
			if len(word) > 2 && strings.Contains(rawTextLower, word) {
				matchCount++
			}
		}
		if len(queryWords) > 0 {
			score += float32(matchCount) / float32(len(queryWords)) * 5
		}
		if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
			score += 3
		}
		// A large per-phrase bonus so exact phrase matches dominate.
		phraseMatches := countPhraseMatches(row.RawText, query)
		if phraseMatches > 0 {
			r.logger.Debug("phrase match bonus", "slug", row.Slug, "phraseMatches", phraseMatches, "score", score)
			score += float32(phraseMatches) * 100
		}
		// Linear bonus (no diminishing returns) for each neighbouring chunk in
		// the same result set, to favour narrative continuity.
		adjacentCount := 0
		for j := range results {
			if results[j].Slug != row.Slug && areSlugsAdjacent(row.Slug, results[j].Slug) {
				adjacentCount++
			}
		}
		score += float32(adjacentCount) * 4
		scored = append(scored, scoredResult{row: row, distance: row.Distance - score/100, phraseMatches: phraseMatches})
	}
	// Stable, so rows with equal adjusted distance keep their input order.
	sort.SliceStable(scored, func(i, j int) bool {
		return scored[i].distance < scored[j].distance
	})
	unique := make([]models.VectorRow, 0)
	seen := make(map[string]bool)
	maxPerFile := 2
	if phraseCount > 0 {
		maxPerFile = 10
	}
	fileCounts := make(map[string]int)
	for i := range scored {
		cand := &scored[i]
		if seen[cand.row.Slug] {
			continue
		}
		used := fileCounts[cand.row.FileName]
		allowed := used < maxPerFile
		if !allowed && cand.phraseMatches > 0 {
			allowed = used < maxPerFile+5
		}
		if !allowed {
			continue
		}
		seen[cand.row.Slug] = true
		fileCounts[cand.row.FileName]++
		unique = append(unique, cand.row)
	}
	if len(unique) > 30 {
		unique = unique[:30]
	}
	return unique
}
// SynthesizeAnswer turns search results into a single answer string for
// query, under the read lock.
//
// It builds a prompt from the query, every result's file name and raw text,
// and fixed instructions. It then embeds that prompt and returns the RawText
// of the single closest stored vector. If that search finds nothing (or the
// stored text equals the prompt), it falls back to a bulleted list of the
// first five results, each shortened with truncateString(..., 200).
//
// NOTE(review): no language model is involved. The "answer" is the stored
// chunk nearest to the prompt embedding, not generated text; confirm this is
// the intended behaviour.
func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	r.resetIdleTimer()
	if len(results) == 0 {
		return "No relevant information found in the vector database.", nil
	}
	var sb strings.Builder
	fmt.Fprintf(&sb, "User Query: %s\n\nRetrieved Context:\n", query)
	for idx, row := range results {
		fmt.Fprintf(&sb, "[Source %d: %s]\n%s\n\n", idx+1, row.FileName, row.RawText)
	}
	sb.WriteString("Instructions: " +
		"Based on the retrieved context above, provide a concise, coherent answer to the user's query. " +
		"Extract only the most relevant information. " +
		"If no relevant information is found, state that clearly. " +
		"Cite sources by filename when relevant. " +
		"Do not include unnecessary preamble or explanations.")
	prompt := sb.String()
	vec, err := r.LineToVector(prompt)
	if err != nil {
		r.logger.Error("failed to embed synthesis prompt", "error", err)
		return "", err
	}
	top, err := r.searchEmb(&models.EmbeddingResp{Embedding: vec, Index: 0}, 1)
	if err != nil {
		r.logger.Error("failed to search for synthesis context", "error", err)
		return "", err
	}
	if len(top) > 0 && top[0].RawText != prompt {
		return top[0].RawText, nil
	}
	shown := len(results)
	if shown > 5 {
		shown = 5
	}
	var out strings.Builder
	out.WriteString("Based on the retrieved context:\n\n")
	for _, row := range results[:shown] {
		fmt.Fprintf(&out, "- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))
	}
	return out.String(), nil
}
// truncateString returns s unchanged if it is at most maxLen bytes long.
// Otherwise it returns the longest prefix of at most maxLen bytes that ends
// on a UTF-8 rune boundary, followed by "...". Slicing mid-rune would emit
// invalid UTF-8. A negative maxLen is treated as 0 instead of panicking.
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	cut := 0
	// range visits rune start offsets in increasing order; keep the last one
	// that still fits.
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
// Search runs a hybrid (vector + full-text) search for query and returns at
// most limit rows, reranked by RerankResults. It runs under the read lock.
//
// The query is refined (RefineQuery) and expanded (GenerateQueryVariations).
// For each variation it fetches the limit*2 nearest vectors and runs a
// keyword search for limit rows. Failures for individual variations are
// logged and skipped.
//
// The two ranked lists are merged with Reciprocal Rank Fusion:
// score = sum of w/(rank+k). When the original query contains detected
// phrases, k drops from 60 to 30 and keyword hits get weight 100, favouring
// exact phrase matches.
//
// A non-positive limit returns an empty slice.
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	r.resetIdleTimer()
	if limit <= 0 {
		// Slicing to a negative limit below would panic.
		return []models.VectorRow{}, nil
	}
	refined := r.RefineQuery(query)
	variations := r.GenerateQueryVariations(refined)
	r.logger.Debug("query variations", "original", query, "refined", refined, "variations", variations)
	// Vector search over every variation, de-duplicated by slug.
	var embResults []models.VectorRow
	seen := make(map[string]bool)
	for _, q := range variations {
		emb, err := r.LineToVector(q)
		if err != nil {
			r.logger.Error("failed to embed query variation", "error", err, "query", q)
			continue
		}
		results, err := r.searchEmb(&models.EmbeddingResp{Embedding: emb, Index: 0}, limit*2)
		if err != nil {
			r.logger.Error("failed to search embeddings", "error", err, "query", q)
			continue
		}
		for _, row := range results {
			if !seen[row.Slug] {
				seen[row.Slug] = true
				embResults = append(embResults, row)
			}
		}
	}
	// Lower distance is better.
	sort.SliceStable(embResults, func(i, j int) bool {
		return embResults[i].Distance < embResults[j].Distance
	})
	// Keyword search over every variation, de-duplicated by slug.
	var kwResults []models.VectorRow
	seenKw := make(map[string]bool)
	for _, q := range variations {
		results, err := r.searchKeyword(q, limit)
		if err != nil {
			r.logger.Debug("keyword search failed for variation", "error", err, "query", q)
			continue
		}
		for _, row := range results {
			if !seenKw[row.Slug] {
				seenKw[row.Slug] = true
				kwResults = append(kwResults, row)
			}
		}
	}
	sort.SliceStable(kwResults, func(i, j int) bool {
		return kwResults[i].Distance < kwResults[j].Distance
	})
	phraseCount := len(detectPhrases(query))
	rrfK := 60.0
	kwWeight := 1.0
	if phraseCount > 0 {
		rrfK = 30.0
		kwWeight = 100.0
	}
	r.logger.Debug("RRF parameters", "phraseCount", phraseCount, "rrfK", rrfK, "kwWeight", kwWeight, "query", query)
	scoreMap := make(map[string]float64, len(embResults)+len(kwResults))
	for rank, row := range embResults {
		scoreMap[row.Slug] += 1.0 / (float64(rank) + rrfK)
	}
	// combined holds each slug once; a row found by both searches keeps its
	// vector-search version.
	combined := embResults
	for rank, row := range kwResults {
		scoreMap[row.Slug] += kwWeight * (1.0 / (float64(rank) + rrfK))
		if !seen[row.Slug] {
			seen[row.Slug] = true
			combined = append(combined, row)
		}
	}
	type scoredRow struct {
		row   models.VectorRow
		score float64
	}
	scoredRows := make([]scoredRow, 0, len(combined))
	for _, row := range combined {
		scoredRows = append(scoredRows, scoredRow{row: row, score: scoreMap[row.Slug]})
	}
	// Highest fused score first; stable so ties keep the combined order.
	sort.SliceStable(scoredRows, func(i, j int) bool {
		return scoredRows[i].score > scoredRows[j].score
	})
	if len(scoredRows) > limit {
		scoredRows = scoredRows[:limit]
	}
	finalResults := make([]models.VectorRow, len(scoredRows))
	for i, sr := range scoredRows {
		finalResults[i] = sr.row
	}
	return r.RerankResults(finalResults, query), nil
}
// ragInstance is the package-wide RAG created by Init and returned by GetInstance.
var ragInstance *RAG

// ragOnce makes Init construct ragInstance at most once.
var ragOnce sync.Once
// FallbackMessage returns the error text recorded when New could not create
// the ONNX embedder and fell back to the API embedder. It is empty otherwise.
func (r *RAG) FallbackMessage() string {
	msg := r.fallbackMsg
	return msg
}
// ragInitErr records the error from the one-time construction in Init, so
// later Init calls report it too.
var ragInitErr error

// Init creates the package-wide RAG instance returned by GetInstance.
//
// Only the first call with non-nil arguments constructs the instance. Later
// calls do nothing and return the error (if any) from that construction.
// A call with a nil config, logger or store returns nil without consuming the
// one-time initialisation, so a later call with real dependencies still
// works.
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
	if c == nil || l == nil || s == nil {
		return nil
	}
	ragOnce.Do(func() {
		ragInstance, ragInitErr = New(l, s, c)
	})
	return ragInitErr
}
// GetInstance returns the RAG created by Init, or nil if Init has not
// constructed one.
func GetInstance() *RAG {
	instance := ragInstance
	return instance
}
// resetIdleTimer (re)starts the idle countdown. If idleTimeout passes without
// another call, freeONNXMemory runs in the timer's own goroutine. On a
// func-based timer, Reset either reschedules the pending call or, if it
// already fired, schedules it again. That is the same effect as stopping the
// timer and creating a new one.
func (r *RAG) resetIdleTimer() {
	r.idleMu.Lock()
	defer r.idleMu.Unlock()
	if r.idleTimer == nil {
		r.idleTimer = time.AfterFunc(r.idleTimeout, r.freeONNXMemory)
		return
	}
	r.idleTimer.Reset(r.idleTimeout)
}
// freeONNXMemory is the idle-timer callback. If the embedder is an
// *ONNXEmbedder, it calls Destroy under the write lock and logs the outcome.
// Other embedders are left alone.
//
// NOTE(review): assumes ONNXEmbedder can be used again (or re-initialises)
// after Destroy; confirm against its implementation.
func (r *RAG) freeONNXMemory() {
	r.mu.Lock()
	defer r.mu.Unlock()
	onnx, ok := r.embedder.(*ONNXEmbedder)
	if !ok {
		return
	}
	if err := onnx.Destroy(); err != nil {
		r.logger.Error("failed to free ONNX memory", "error", err)
		return
	}
	r.logger.Info("freed ONNX VRAM after idle timeout")
}
// Destroy stops the idle timer and, if the embedder is an *ONNXEmbedder,
// destroys it. Errors are logged, not returned.
//
// idleTimer is touched under idleMu, the same lock resetIdleTimer uses. The
// lock order is mu then idleMu, matching callers that reset the timer while
// holding mu.
//
// NOTE(review): if the idle timer already fired, its freeONNXMemory call may
// still run after Destroy and call ONNXEmbedder.Destroy a second time; confirm
// that is safe.
func (r *RAG) Destroy() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.idleMu.Lock()
	if r.idleTimer != nil {
		r.idleTimer.Stop()
		r.idleTimer = nil
	}
	r.idleMu.Unlock()
	if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
		if err := onnx.Destroy(); err != nil {
			r.logger.Error("failed to destroy ONNX embedder", "error", err)
		}
	}
}
// SetEmbedderForTesting replaces the internal embedder (e.g. with a mock)
// under the write lock. It is intended for tests only. It is not
// build-tag-gated, so it is compiled into regular builds as well.
func (r *RAG) SetEmbedderForTesting(e Embedder) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.embedder = e
}