Feat (rag): hybrid search attempt
This commit is contained in:
176
rag/rag.go
176
rag/rag.go
@@ -73,6 +73,74 @@ func wordCounter(sentence string) int {
|
||||
return len(strings.Split(strings.TrimSpace(sentence), " "))
|
||||
}
|
||||
|
||||
func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
|
||||
if len(sentences) == 0 {
|
||||
return nil
|
||||
}
|
||||
if overlapWords >= wordLimit {
|
||||
overlapWords = wordLimit / 2
|
||||
}
|
||||
var chunks []string
|
||||
i := 0
|
||||
for i < len(sentences) {
|
||||
var chunkWords []string
|
||||
wordCount := 0
|
||||
j := i
|
||||
for j < len(sentences) && wordCount <= int(wordLimit) {
|
||||
sentence := sentences[j]
|
||||
words := strings.Fields(sentence)
|
||||
chunkWords = append(chunkWords, sentence)
|
||||
wordCount += len(words)
|
||||
j++
|
||||
// If this sentence alone exceeds limit, still include it and stop
|
||||
if wordCount > int(wordLimit) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(chunkWords) == 0 {
|
||||
break
|
||||
}
|
||||
chunk := strings.Join(chunkWords, " ")
|
||||
chunks = append(chunks, chunk)
|
||||
if j >= len(sentences) {
|
||||
break
|
||||
}
|
||||
// Move i forward by skipping overlap
|
||||
if overlapWords == 0 {
|
||||
i = j
|
||||
continue
|
||||
}
|
||||
// Calculate how many sentences to skip to achieve overlapWords
|
||||
overlapRemaining := int(overlapWords)
|
||||
newI := i
|
||||
for newI < j && overlapRemaining > 0 {
|
||||
words := len(strings.Fields(sentences[newI]))
|
||||
overlapRemaining -= words
|
||||
if overlapRemaining >= 0 {
|
||||
newI++
|
||||
}
|
||||
}
|
||||
if newI == i {
|
||||
newI = j
|
||||
}
|
||||
i = newI
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
func sanitizeFTSQuery(query string) string {
|
||||
// Remove double quotes and other problematic characters for FTS5
|
||||
query = strings.ReplaceAll(query, "\"", " ")
|
||||
query = strings.ReplaceAll(query, "'", " ")
|
||||
query = strings.ReplaceAll(query, ";", " ")
|
||||
query = strings.ReplaceAll(query, "\\", " ")
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return "*" // match all
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func (r *RAG) LoadRAG(fpath string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
@@ -95,31 +163,8 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
for i, s := range sentences {
|
||||
sents[i] = s.Text
|
||||
}
|
||||
// Group sentences into paragraphs based on word limit
|
||||
paragraphs := []string{}
|
||||
par := strings.Builder{}
|
||||
for i := 0; i < len(sents); i++ {
|
||||
if strings.TrimSpace(sents[i]) != "" {
|
||||
if par.Len() > 0 {
|
||||
par.WriteString(" ")
|
||||
}
|
||||
par.WriteString(sents[i])
|
||||
}
|
||||
if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
|
||||
paragraph := strings.TrimSpace(par.String())
|
||||
if paragraph != "" {
|
||||
paragraphs = append(paragraphs, paragraph)
|
||||
}
|
||||
par.Reset()
|
||||
}
|
||||
}
|
||||
// Handle any remaining content in the paragraph buffer
|
||||
if par.Len() > 0 {
|
||||
paragraph := strings.TrimSpace(par.String())
|
||||
if paragraph != "" {
|
||||
paragraphs = append(paragraphs, paragraph)
|
||||
}
|
||||
}
|
||||
// Create chunks with overlap
|
||||
paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords)
|
||||
// Adjust batch size if needed
|
||||
if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 {
|
||||
r.cfg.RAGBatchSize = len(paragraphs)
|
||||
@@ -205,9 +250,15 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||
return r.embedder.Embed(line)
|
||||
}
|
||||
|
||||
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
|
||||
func (r *RAG) SearchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
|
||||
r.resetIdleTimer()
|
||||
return r.storage.SearchClosest(emb.Embedding)
|
||||
return r.storage.SearchClosest(emb.Embedding, limit)
|
||||
}
|
||||
|
||||
func (r *RAG) SearchKeyword(query string, limit int) ([]models.VectorRow, error) {
|
||||
r.resetIdleTimer()
|
||||
sanitized := sanitizeFTSQuery(query)
|
||||
return r.storage.SearchKeyword(sanitized, limit)
|
||||
}
|
||||
|
||||
func (r *RAG) ListLoaded() ([]string, error) {
|
||||
@@ -393,7 +444,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
|
||||
Embedding: emb,
|
||||
Index: 0,
|
||||
}
|
||||
topResults, err := r.SearchEmb(embResp)
|
||||
topResults, err := r.SearchEmb(embResp, 1)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to search for synthesis context", "error", err)
|
||||
return "", err
|
||||
@@ -422,7 +473,9 @@ func truncateString(s string, maxLen int) string {
|
||||
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||
refined := r.RefineQuery(query)
|
||||
variations := r.GenerateQueryVariations(refined)
|
||||
allResults := make([]models.VectorRow, 0)
|
||||
|
||||
// Collect embedding search results from all variations
|
||||
var embResults []models.VectorRow
|
||||
seen := make(map[string]bool)
|
||||
for _, q := range variations {
|
||||
emb, err := r.LineToVector(q)
|
||||
@@ -430,29 +483,78 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||
r.logger.Error("failed to embed query variation", "error", err, "query", q)
|
||||
continue
|
||||
}
|
||||
|
||||
embResp := &models.EmbeddingResp{
|
||||
Embedding: emb,
|
||||
Index: 0,
|
||||
}
|
||||
|
||||
results, err := r.SearchEmb(embResp)
|
||||
results, err := r.SearchEmb(embResp, limit*2) // Get more candidates
|
||||
if err != nil {
|
||||
r.logger.Error("failed to search embeddings", "error", err, "query", q)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, row := range results {
|
||||
if !seen[row.Slug] {
|
||||
seen[row.Slug] = true
|
||||
allResults = append(allResults, row)
|
||||
embResults = append(embResults, row)
|
||||
}
|
||||
}
|
||||
}
|
||||
reranked := r.RerankResults(allResults, query)
|
||||
if len(reranked) > limit {
|
||||
reranked = reranked[:limit]
|
||||
// Sort embedding results by distance (lower is better)
|
||||
sort.Slice(embResults, func(i, j int) bool {
|
||||
return embResults[i].Distance < embResults[j].Distance
|
||||
})
|
||||
|
||||
// Perform keyword search
|
||||
kwResults, err := r.SearchKeyword(refined, limit*2)
|
||||
if err != nil {
|
||||
r.logger.Warn("keyword search failed, using only embeddings", "error", err)
|
||||
kwResults = nil
|
||||
}
|
||||
// Sort keyword results by distance (already sorted by BM25 score)
|
||||
// kwResults already sorted by distance (lower is better)
|
||||
|
||||
// Combine using Reciprocal Rank Fusion (RRF)
|
||||
const rrfK = 60
|
||||
type scoredRow struct {
|
||||
row models.VectorRow
|
||||
score float64
|
||||
}
|
||||
scoreMap := make(map[string]float64)
|
||||
// Add embedding results
|
||||
for rank, row := range embResults {
|
||||
score := 1.0 / (float64(rank) + rrfK)
|
||||
scoreMap[row.Slug] += score
|
||||
}
|
||||
// Add keyword results
|
||||
for rank, row := range kwResults {
|
||||
score := 1.0 / (float64(rank) + rrfK)
|
||||
scoreMap[row.Slug] += score
|
||||
// Ensure row exists in combined results
|
||||
if _, exists := seen[row.Slug]; !exists {
|
||||
embResults = append(embResults, row)
|
||||
}
|
||||
}
|
||||
// Create slice of scored rows
|
||||
scoredRows := make([]scoredRow, 0, len(embResults))
|
||||
for _, row := range embResults {
|
||||
score := scoreMap[row.Slug]
|
||||
scoredRows = append(scoredRows, scoredRow{row: row, score: score})
|
||||
}
|
||||
// Sort by descending RRF score
|
||||
sort.Slice(scoredRows, func(i, j int) bool {
|
||||
return scoredRows[i].score > scoredRows[j].score
|
||||
})
|
||||
// Take top limit
|
||||
if len(scoredRows) > limit {
|
||||
scoredRows = scoredRows[:limit]
|
||||
}
|
||||
// Convert back to VectorRow
|
||||
finalResults := make([]models.VectorRow, len(scoredRows))
|
||||
for i, sr := range scoredRows {
|
||||
finalResults[i] = sr.row
|
||||
}
|
||||
// Apply reranking heuristics
|
||||
reranked := r.RerankResults(finalResults, query)
|
||||
return reranked, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user