Feat (rag): hybrid search attempt

This commit is contained in:
Grail Finder
2026-03-06 11:20:50 +03:00
parent 822cc48834
commit f9866bcf5a
9 changed files with 305 additions and 81 deletions

View File

@@ -73,6 +73,74 @@ func wordCounter(sentence string) int {
return len(strings.Split(strings.TrimSpace(sentence), " "))
}
func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
if len(sentences) == 0 {
return nil
}
if overlapWords >= wordLimit {
overlapWords = wordLimit / 2
}
var chunks []string
i := 0
for i < len(sentences) {
var chunkWords []string
wordCount := 0
j := i
for j < len(sentences) && wordCount <= int(wordLimit) {
sentence := sentences[j]
words := strings.Fields(sentence)
chunkWords = append(chunkWords, sentence)
wordCount += len(words)
j++
// If this sentence alone exceeds limit, still include it and stop
if wordCount > int(wordLimit) {
break
}
}
if len(chunkWords) == 0 {
break
}
chunk := strings.Join(chunkWords, " ")
chunks = append(chunks, chunk)
if j >= len(sentences) {
break
}
// Move i forward by skipping overlap
if overlapWords == 0 {
i = j
continue
}
// Calculate how many sentences to skip to achieve overlapWords
overlapRemaining := int(overlapWords)
newI := i
for newI < j && overlapRemaining > 0 {
words := len(strings.Fields(sentences[newI]))
overlapRemaining -= words
if overlapRemaining >= 0 {
newI++
}
}
if newI == i {
newI = j
}
i = newI
}
return chunks
}
func sanitizeFTSQuery(query string) string {
// Remove double quotes and other problematic characters for FTS5
query = strings.ReplaceAll(query, "\"", " ")
query = strings.ReplaceAll(query, "'", " ")
query = strings.ReplaceAll(query, ";", " ")
query = strings.ReplaceAll(query, "\\", " ")
query = strings.TrimSpace(query)
if query == "" {
return "*" // match all
}
return query
}
func (r *RAG) LoadRAG(fpath string) error {
r.mu.Lock()
defer r.mu.Unlock()
@@ -95,31 +163,8 @@ func (r *RAG) LoadRAG(fpath string) error {
for i, s := range sentences {
sents[i] = s.Text
}
// Group sentences into paragraphs based on word limit
paragraphs := []string{}
par := strings.Builder{}
for i := 0; i < len(sents); i++ {
if strings.TrimSpace(sents[i]) != "" {
if par.Len() > 0 {
par.WriteString(" ")
}
par.WriteString(sents[i])
}
if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
paragraph := strings.TrimSpace(par.String())
if paragraph != "" {
paragraphs = append(paragraphs, paragraph)
}
par.Reset()
}
}
// Handle any remaining content in the paragraph buffer
if par.Len() > 0 {
paragraph := strings.TrimSpace(par.String())
if paragraph != "" {
paragraphs = append(paragraphs, paragraph)
}
}
// Create chunks with overlap
paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords)
// Adjust batch size if needed
if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 {
r.cfg.RAGBatchSize = len(paragraphs)
@@ -205,9 +250,15 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
return r.embedder.Embed(line)
}
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
func (r *RAG) SearchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
r.resetIdleTimer()
return r.storage.SearchClosest(emb.Embedding)
return r.storage.SearchClosest(emb.Embedding, limit)
}
func (r *RAG) SearchKeyword(query string, limit int) ([]models.VectorRow, error) {
r.resetIdleTimer()
sanitized := sanitizeFTSQuery(query)
return r.storage.SearchKeyword(sanitized, limit)
}
func (r *RAG) ListLoaded() ([]string, error) {
@@ -393,7 +444,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
Embedding: emb,
Index: 0,
}
topResults, err := r.SearchEmb(embResp)
topResults, err := r.SearchEmb(embResp, 1)
if err != nil {
r.logger.Error("failed to search for synthesis context", "error", err)
return "", err
@@ -422,7 +473,9 @@ func truncateString(s string, maxLen int) string {
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
refined := r.RefineQuery(query)
variations := r.GenerateQueryVariations(refined)
allResults := make([]models.VectorRow, 0)
// Collect embedding search results from all variations
var embResults []models.VectorRow
seen := make(map[string]bool)
for _, q := range variations {
emb, err := r.LineToVector(q)
@@ -430,29 +483,78 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
r.logger.Error("failed to embed query variation", "error", err, "query", q)
continue
}
embResp := &models.EmbeddingResp{
Embedding: emb,
Index: 0,
}
results, err := r.SearchEmb(embResp)
results, err := r.SearchEmb(embResp, limit*2) // Get more candidates
if err != nil {
r.logger.Error("failed to search embeddings", "error", err, "query", q)
continue
}
for _, row := range results {
if !seen[row.Slug] {
seen[row.Slug] = true
allResults = append(allResults, row)
embResults = append(embResults, row)
}
}
}
reranked := r.RerankResults(allResults, query)
if len(reranked) > limit {
reranked = reranked[:limit]
// Sort embedding results by distance (lower is better)
sort.Slice(embResults, func(i, j int) bool {
return embResults[i].Distance < embResults[j].Distance
})
// Perform keyword search
kwResults, err := r.SearchKeyword(refined, limit*2)
if err != nil {
r.logger.Warn("keyword search failed, using only embeddings", "error", err)
kwResults = nil
}
// Sort keyword results by distance (already sorted by BM25 score)
// kwResults already sorted by distance (lower is better)
// Combine using Reciprocal Rank Fusion (RRF)
const rrfK = 60
type scoredRow struct {
row models.VectorRow
score float64
}
scoreMap := make(map[string]float64)
// Add embedding results
for rank, row := range embResults {
score := 1.0 / (float64(rank) + rrfK)
scoreMap[row.Slug] += score
}
// Add keyword results
for rank, row := range kwResults {
score := 1.0 / (float64(rank) + rrfK)
scoreMap[row.Slug] += score
// Ensure row exists in combined results
if _, exists := seen[row.Slug]; !exists {
embResults = append(embResults, row)
}
}
// Create slice of scored rows
scoredRows := make([]scoredRow, 0, len(embResults))
for _, row := range embResults {
score := scoreMap[row.Slug]
scoredRows = append(scoredRows, scoredRow{row: row, score: score})
}
// Sort by descending RRF score
sort.Slice(scoredRows, func(i, j int) bool {
return scoredRows[i].score > scoredRows[j].score
})
// Take top limit
if len(scoredRows) > limit {
scoredRows = scoredRows[:limit]
}
// Convert back to VectorRow
finalResults := make([]models.VectorRow, len(scoredRows))
for i, sr := range scoredRows {
finalResults[i] = sr.row
}
// Apply reranking heuristics
reranked := r.RerankResults(finalResults, query)
return reranked, nil
}