Feat (rag): hybrid search attempt
This commit is contained in:
@@ -29,6 +29,7 @@ AutoCleanToolCallsFromCtx = false
|
|||||||
# rag settings
|
# rag settings
|
||||||
RAGBatchSize = 1
|
RAGBatchSize = 1
|
||||||
RAGWordLimit = 80
|
RAGWordLimit = 80
|
||||||
|
RAGOverlapWords = 16
|
||||||
RAGDir = "ragimport"
|
RAGDir = "ragimport"
|
||||||
# extra tts
|
# extra tts
|
||||||
TTS_ENABLED = false
|
TTS_ENABLED = false
|
||||||
|
|||||||
@@ -40,9 +40,10 @@ type Config struct {
|
|||||||
EmbedTokenizerPath string `toml:"EmbedTokenizerPath"`
|
EmbedTokenizerPath string `toml:"EmbedTokenizerPath"`
|
||||||
EmbedDims int `toml:"EmbedDims"`
|
EmbedDims int `toml:"EmbedDims"`
|
||||||
// rag settings
|
// rag settings
|
||||||
RAGDir string `toml:"RAGDir"`
|
RAGDir string `toml:"RAGDir"`
|
||||||
RAGBatchSize int `toml:"RAGBatchSize"`
|
RAGBatchSize int `toml:"RAGBatchSize"`
|
||||||
RAGWordLimit uint32 `toml:"RAGWordLimit"`
|
RAGWordLimit uint32 `toml:"RAGWordLimit"`
|
||||||
|
RAGOverlapWords uint32 `toml:"RAGOverlapWords"`
|
||||||
// deepseek
|
// deepseek
|
||||||
DeepSeekChatAPI string `toml:"DeepSeekChatAPI"`
|
DeepSeekChatAPI string `toml:"DeepSeekChatAPI"`
|
||||||
DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"`
|
DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"`
|
||||||
|
|||||||
176
rag/rag.go
176
rag/rag.go
@@ -73,6 +73,74 @@ func wordCounter(sentence string) int {
|
|||||||
return len(strings.Split(strings.TrimSpace(sentence), " "))
|
return len(strings.Split(strings.TrimSpace(sentence), " "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
|
||||||
|
if len(sentences) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if overlapWords >= wordLimit {
|
||||||
|
overlapWords = wordLimit / 2
|
||||||
|
}
|
||||||
|
var chunks []string
|
||||||
|
i := 0
|
||||||
|
for i < len(sentences) {
|
||||||
|
var chunkWords []string
|
||||||
|
wordCount := 0
|
||||||
|
j := i
|
||||||
|
for j < len(sentences) && wordCount <= int(wordLimit) {
|
||||||
|
sentence := sentences[j]
|
||||||
|
words := strings.Fields(sentence)
|
||||||
|
chunkWords = append(chunkWords, sentence)
|
||||||
|
wordCount += len(words)
|
||||||
|
j++
|
||||||
|
// If this sentence alone exceeds limit, still include it and stop
|
||||||
|
if wordCount > int(wordLimit) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(chunkWords) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
chunk := strings.Join(chunkWords, " ")
|
||||||
|
chunks = append(chunks, chunk)
|
||||||
|
if j >= len(sentences) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Move i forward by skipping overlap
|
||||||
|
if overlapWords == 0 {
|
||||||
|
i = j
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Calculate how many sentences to skip to achieve overlapWords
|
||||||
|
overlapRemaining := int(overlapWords)
|
||||||
|
newI := i
|
||||||
|
for newI < j && overlapRemaining > 0 {
|
||||||
|
words := len(strings.Fields(sentences[newI]))
|
||||||
|
overlapRemaining -= words
|
||||||
|
if overlapRemaining >= 0 {
|
||||||
|
newI++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newI == i {
|
||||||
|
newI = j
|
||||||
|
}
|
||||||
|
i = newI
|
||||||
|
}
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeFTSQuery(query string) string {
|
||||||
|
// Remove double quotes and other problematic characters for FTS5
|
||||||
|
query = strings.ReplaceAll(query, "\"", " ")
|
||||||
|
query = strings.ReplaceAll(query, "'", " ")
|
||||||
|
query = strings.ReplaceAll(query, ";", " ")
|
||||||
|
query = strings.ReplaceAll(query, "\\", " ")
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if query == "" {
|
||||||
|
return "*" // match all
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
func (r *RAG) LoadRAG(fpath string) error {
|
func (r *RAG) LoadRAG(fpath string) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -95,31 +163,8 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
for i, s := range sentences {
|
for i, s := range sentences {
|
||||||
sents[i] = s.Text
|
sents[i] = s.Text
|
||||||
}
|
}
|
||||||
// Group sentences into paragraphs based on word limit
|
// Create chunks with overlap
|
||||||
paragraphs := []string{}
|
paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords)
|
||||||
par := strings.Builder{}
|
|
||||||
for i := 0; i < len(sents); i++ {
|
|
||||||
if strings.TrimSpace(sents[i]) != "" {
|
|
||||||
if par.Len() > 0 {
|
|
||||||
par.WriteString(" ")
|
|
||||||
}
|
|
||||||
par.WriteString(sents[i])
|
|
||||||
}
|
|
||||||
if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
|
|
||||||
paragraph := strings.TrimSpace(par.String())
|
|
||||||
if paragraph != "" {
|
|
||||||
paragraphs = append(paragraphs, paragraph)
|
|
||||||
}
|
|
||||||
par.Reset()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Handle any remaining content in the paragraph buffer
|
|
||||||
if par.Len() > 0 {
|
|
||||||
paragraph := strings.TrimSpace(par.String())
|
|
||||||
if paragraph != "" {
|
|
||||||
paragraphs = append(paragraphs, paragraph)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Adjust batch size if needed
|
// Adjust batch size if needed
|
||||||
if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 {
|
if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 {
|
||||||
r.cfg.RAGBatchSize = len(paragraphs)
|
r.cfg.RAGBatchSize = len(paragraphs)
|
||||||
@@ -205,9 +250,15 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
|
|||||||
return r.embedder.Embed(line)
|
return r.embedder.Embed(line)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
|
func (r *RAG) SearchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
|
||||||
r.resetIdleTimer()
|
r.resetIdleTimer()
|
||||||
return r.storage.SearchClosest(emb.Embedding)
|
return r.storage.SearchClosest(emb.Embedding, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RAG) SearchKeyword(query string, limit int) ([]models.VectorRow, error) {
|
||||||
|
r.resetIdleTimer()
|
||||||
|
sanitized := sanitizeFTSQuery(query)
|
||||||
|
return r.storage.SearchKeyword(sanitized, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) ListLoaded() ([]string, error) {
|
func (r *RAG) ListLoaded() ([]string, error) {
|
||||||
@@ -393,7 +444,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
|
|||||||
Embedding: emb,
|
Embedding: emb,
|
||||||
Index: 0,
|
Index: 0,
|
||||||
}
|
}
|
||||||
topResults, err := r.SearchEmb(embResp)
|
topResults, err := r.SearchEmb(embResp, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to search for synthesis context", "error", err)
|
r.logger.Error("failed to search for synthesis context", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -422,7 +473,9 @@ func truncateString(s string, maxLen int) string {
|
|||||||
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||||
refined := r.RefineQuery(query)
|
refined := r.RefineQuery(query)
|
||||||
variations := r.GenerateQueryVariations(refined)
|
variations := r.GenerateQueryVariations(refined)
|
||||||
allResults := make([]models.VectorRow, 0)
|
|
||||||
|
// Collect embedding search results from all variations
|
||||||
|
var embResults []models.VectorRow
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
for _, q := range variations {
|
for _, q := range variations {
|
||||||
emb, err := r.LineToVector(q)
|
emb, err := r.LineToVector(q)
|
||||||
@@ -430,29 +483,78 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
|||||||
r.logger.Error("failed to embed query variation", "error", err, "query", q)
|
r.logger.Error("failed to embed query variation", "error", err, "query", q)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
embResp := &models.EmbeddingResp{
|
embResp := &models.EmbeddingResp{
|
||||||
Embedding: emb,
|
Embedding: emb,
|
||||||
Index: 0,
|
Index: 0,
|
||||||
}
|
}
|
||||||
|
results, err := r.SearchEmb(embResp, limit*2) // Get more candidates
|
||||||
results, err := r.SearchEmb(embResp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to search embeddings", "error", err, "query", q)
|
r.logger.Error("failed to search embeddings", "error", err, "query", q)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, row := range results {
|
for _, row := range results {
|
||||||
if !seen[row.Slug] {
|
if !seen[row.Slug] {
|
||||||
seen[row.Slug] = true
|
seen[row.Slug] = true
|
||||||
allResults = append(allResults, row)
|
embResults = append(embResults, row)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
reranked := r.RerankResults(allResults, query)
|
// Sort embedding results by distance (lower is better)
|
||||||
if len(reranked) > limit {
|
sort.Slice(embResults, func(i, j int) bool {
|
||||||
reranked = reranked[:limit]
|
return embResults[i].Distance < embResults[j].Distance
|
||||||
|
})
|
||||||
|
|
||||||
|
// Perform keyword search
|
||||||
|
kwResults, err := r.SearchKeyword(refined, limit*2)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Warn("keyword search failed, using only embeddings", "error", err)
|
||||||
|
kwResults = nil
|
||||||
}
|
}
|
||||||
|
// Sort keyword results by distance (already sorted by BM25 score)
|
||||||
|
// kwResults already sorted by distance (lower is better)
|
||||||
|
|
||||||
|
// Combine using Reciprocal Rank Fusion (RRF)
|
||||||
|
const rrfK = 60
|
||||||
|
type scoredRow struct {
|
||||||
|
row models.VectorRow
|
||||||
|
score float64
|
||||||
|
}
|
||||||
|
scoreMap := make(map[string]float64)
|
||||||
|
// Add embedding results
|
||||||
|
for rank, row := range embResults {
|
||||||
|
score := 1.0 / (float64(rank) + rrfK)
|
||||||
|
scoreMap[row.Slug] += score
|
||||||
|
}
|
||||||
|
// Add keyword results
|
||||||
|
for rank, row := range kwResults {
|
||||||
|
score := 1.0 / (float64(rank) + rrfK)
|
||||||
|
scoreMap[row.Slug] += score
|
||||||
|
// Ensure row exists in combined results
|
||||||
|
if _, exists := seen[row.Slug]; !exists {
|
||||||
|
embResults = append(embResults, row)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Create slice of scored rows
|
||||||
|
scoredRows := make([]scoredRow, 0, len(embResults))
|
||||||
|
for _, row := range embResults {
|
||||||
|
score := scoreMap[row.Slug]
|
||||||
|
scoredRows = append(scoredRows, scoredRow{row: row, score: score})
|
||||||
|
}
|
||||||
|
// Sort by descending RRF score
|
||||||
|
sort.Slice(scoredRows, func(i, j int) bool {
|
||||||
|
return scoredRows[i].score > scoredRows[j].score
|
||||||
|
})
|
||||||
|
// Take top limit
|
||||||
|
if len(scoredRows) > limit {
|
||||||
|
scoredRows = scoredRows[:limit]
|
||||||
|
}
|
||||||
|
// Convert back to VectorRow
|
||||||
|
finalResults := make([]models.VectorRow, len(scoredRows))
|
||||||
|
for i, sr := range scoredRows {
|
||||||
|
finalResults[i] = sr.row
|
||||||
|
}
|
||||||
|
// Apply reranking heuristics
|
||||||
|
reranked := r.RerankResults(finalResults, query)
|
||||||
return reranked, nil
|
return reranked, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
119
rag/storage.go
119
rag/storage.go
@@ -62,6 +62,18 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
embeddingSize := len(row.Embeddings)
|
||||||
|
|
||||||
|
// Start transaction
|
||||||
|
tx, err := vs.sqlxDB.Beginx()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Serialize the embeddings to binary
|
// Serialize the embeddings to binary
|
||||||
serializedEmbeddings := SerializeVector(row.Embeddings)
|
serializedEmbeddings := SerializeVector(row.Embeddings)
|
||||||
@@ -69,10 +81,23 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
|
|||||||
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
|
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
|
||||||
tableName,
|
tableName,
|
||||||
)
|
)
|
||||||
if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
|
if _, err := tx.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
|
||||||
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
|
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Insert into FTS table
|
||||||
|
ftsQuery := `INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES (?, ?, ?, ?)`
|
||||||
|
if _, err := tx.Exec(ftsQuery, row.Slug, row.RawText, row.FileName, embeddingSize); err != nil {
|
||||||
|
vs.logger.Error("failed to write to FTS table", "error", err, "slug", row.Slug)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
vs.logger.Error("failed to commit transaction", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,16 +123,15 @@ func (vs *VectorStorage) getTableName(emb []float32) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SearchClosest finds vectors closest to the query vector using efficient cosine similarity calculation
|
// SearchClosest finds vectors closest to the query vector using efficient cosine similarity calculation
|
||||||
func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, error) {
|
func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.VectorRow, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 10
|
||||||
|
}
|
||||||
tableName, err := vs.getTableName(query)
|
tableName, err := vs.getTableName(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// For better performance, instead of loading all vectors at once,
|
|
||||||
// we'll implement batching and potentially add L2 distance-based pre-filtering
|
|
||||||
// since cosine similarity is related to L2 distance for normalized vectors
|
|
||||||
|
|
||||||
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
|
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
|
||||||
rows, err := vs.sqlxDB.Query(querySQL)
|
rows, err := vs.sqlxDB.Query(querySQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -115,13 +139,11 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
// Use a min-heap or simple slice to keep track of top 3 closest vectors
|
|
||||||
type SearchResult struct {
|
type SearchResult struct {
|
||||||
vector models.VectorRow
|
vector models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
var topResults []SearchResult
|
var topResults []SearchResult
|
||||||
// Process vectors one by one to avoid loading everything into memory
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
embeddingsBlob []byte
|
embeddingsBlob []byte
|
||||||
@@ -134,10 +156,8 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
storedEmbeddings := DeserializeVector(embeddingsBlob)
|
storedEmbeddings := DeserializeVector(embeddingsBlob)
|
||||||
|
|
||||||
// Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar)
|
|
||||||
similarity := cosineSimilarity(query, storedEmbeddings)
|
similarity := cosineSimilarity(query, storedEmbeddings)
|
||||||
distance := 1 - similarity // Convert to distance where 0 is most similar
|
distance := 1 - similarity
|
||||||
|
|
||||||
result := SearchResult{
|
result := SearchResult{
|
||||||
vector: models.VectorRow{
|
vector: models.VectorRow{
|
||||||
@@ -149,20 +169,15 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
distance: distance,
|
distance: distance,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add to top results and maintain only top 3
|
|
||||||
topResults = append(topResults, result)
|
topResults = append(topResults, result)
|
||||||
|
|
||||||
// Sort and keep only top 3
|
|
||||||
sort.Slice(topResults, func(i, j int) bool {
|
sort.Slice(topResults, func(i, j int) bool {
|
||||||
return topResults[i].distance < topResults[j].distance
|
return topResults[i].distance < topResults[j].distance
|
||||||
})
|
})
|
||||||
|
if len(topResults) > limit {
|
||||||
if len(topResults) > 3 {
|
topResults = topResults[:limit]
|
||||||
topResults = topResults[:3] // Keep only closest 3
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert back to VectorRow slice
|
|
||||||
results := make([]models.VectorRow, 0, len(topResults))
|
results := make([]models.VectorRow, 0, len(topResults))
|
||||||
for _, result := range topResults {
|
for _, result := range topResults {
|
||||||
result.vector.Distance = result.distance
|
result.vector.Distance = result.distance
|
||||||
@@ -171,6 +186,70 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetVectorBySlug retrieves a vector row by its slug
|
||||||
|
func (vs *VectorStorage) GetVectorBySlug(slug string) (*models.VectorRow, error) {
|
||||||
|
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
||||||
|
for _, size := range embeddingSizes {
|
||||||
|
table := fmt.Sprintf("embeddings_%d", size)
|
||||||
|
query := fmt.Sprintf("SELECT embeddings, slug, raw_text, filename FROM %s WHERE slug = ?", table)
|
||||||
|
row := vs.sqlxDB.QueryRow(query, slug)
|
||||||
|
var (
|
||||||
|
embeddingsBlob []byte
|
||||||
|
retrievedSlug, rawText, fileName string
|
||||||
|
)
|
||||||
|
if err := row.Scan(&embeddingsBlob, &retrievedSlug, &rawText, &fileName); err != nil {
|
||||||
|
// No row in this table, continue to next size
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
storedEmbeddings := DeserializeVector(embeddingsBlob)
|
||||||
|
return &models.VectorRow{
|
||||||
|
Embeddings: storedEmbeddings,
|
||||||
|
Slug: retrievedSlug,
|
||||||
|
RawText: rawText,
|
||||||
|
FileName: fileName,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("vector with slug %s not found", slug)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchKeyword performs full-text search using FTS5
|
||||||
|
func (vs *VectorStorage) SearchKeyword(query string, limit int) ([]models.VectorRow, error) {
|
||||||
|
// Use FTS5 bm25 ranking. bm25 returns negative values where more negative is better.
|
||||||
|
// We'll order by bm25 (ascending) and limit.
|
||||||
|
ftsQuery := `SELECT slug, raw_text, filename, bm25(fts_embeddings) as score
|
||||||
|
FROM fts_embeddings
|
||||||
|
WHERE fts_embeddings MATCH ?
|
||||||
|
ORDER BY score
|
||||||
|
LIMIT ?`
|
||||||
|
rows, err := vs.sqlxDB.Query(ftsQuery, query, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("FTS search failed: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var results []models.VectorRow
|
||||||
|
for rows.Next() {
|
||||||
|
var slug, rawText, fileName string
|
||||||
|
var score float64
|
||||||
|
if err := rows.Scan(&slug, &rawText, &fileName, &score); err != nil {
|
||||||
|
vs.logger.Error("failed to scan FTS row", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Convert BM25 score to distance-like metric (lower is better)
|
||||||
|
// BM25 is negative, more negative is better. We'll normalize to positive distance.
|
||||||
|
distance := float32(-score) // Make positive (since score is negative)
|
||||||
|
if distance < 0 {
|
||||||
|
distance = 0
|
||||||
|
}
|
||||||
|
results = append(results, models.VectorRow{
|
||||||
|
Slug: slug,
|
||||||
|
RawText: rawText,
|
||||||
|
FileName: fileName,
|
||||||
|
Distance: distance,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ListFiles returns a list of all loaded files
|
// ListFiles returns a list of all loaded files
|
||||||
func (vs *VectorStorage) ListFiles() ([]string, error) {
|
func (vs *VectorStorage) ListFiles() ([]string, error) {
|
||||||
fileLists := make([][]string, 0)
|
fileLists := make([][]string, 0)
|
||||||
@@ -215,6 +294,10 @@ func (vs *VectorStorage) ListFiles() ([]string, error) {
|
|||||||
// RemoveEmbByFileName removes all embeddings associated with a specific filename
|
// RemoveEmbByFileName removes all embeddings associated with a specific filename
|
||||||
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
||||||
var errors []string
|
var errors []string
|
||||||
|
// Delete from FTS table first
|
||||||
|
if _, err := vs.sqlxDB.Exec("DELETE FROM fts_embeddings WHERE filename = ?", filename); err != nil {
|
||||||
|
errors = append(errors, err.Error())
|
||||||
|
}
|
||||||
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
||||||
for _, size := range embeddingSizes {
|
for _, size := range embeddingSizes {
|
||||||
table := fmt.Sprintf("embeddings_%d", size)
|
table := fmt.Sprintf("embeddings_%d", size)
|
||||||
|
|||||||
2
storage/migrations/003_add_fts.down.sql
Normal file
2
storage/migrations/003_add_fts.down.sql
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
-- Drop FTS5 virtual table
|
||||||
|
DROP TABLE IF EXISTS fts_embeddings;
|
||||||
15
storage/migrations/003_add_fts.up.sql
Normal file
15
storage/migrations/003_add_fts.up.sql
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
-- Create FTS5 virtual table for full-text search
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS fts_embeddings USING fts5(
|
||||||
|
slug UNINDEXED,
|
||||||
|
raw_text,
|
||||||
|
filename UNINDEXED,
|
||||||
|
embedding_size UNINDEXED,
|
||||||
|
tokenize='porter unicode61' -- Use porter stemmer and unicode61 tokenizer
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Create triggers to maintain FTS table when embeddings are inserted/deleted
|
||||||
|
-- Note: We'll handle inserts/deletes programmatically for simplicity
|
||||||
|
-- but triggers could be added here if needed.
|
||||||
|
|
||||||
|
-- Indexes for performance (FTS5 manages its own indexes)
|
||||||
|
-- No additional indexes needed for FTS5 virtual table.
|
||||||
2
storage/migrations/004_populate_fts.down.sql
Normal file
2
storage/migrations/004_populate_fts.down.sql
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
-- Clear FTS table (optional)
|
||||||
|
DELETE FROM fts_embeddings;
|
||||||
26
storage/migrations/004_populate_fts.up.sql
Normal file
26
storage/migrations/004_populate_fts.up.sql
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
-- Populate FTS table with existing embeddings
|
||||||
|
DELETE FROM fts_embeddings;
|
||||||
|
|
||||||
|
INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size)
|
||||||
|
SELECT slug, raw_text, filename, 384 FROM embeddings_384;
|
||||||
|
|
||||||
|
INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size)
|
||||||
|
SELECT slug, raw_text, filename, 768 FROM embeddings_768;
|
||||||
|
|
||||||
|
INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size)
|
||||||
|
SELECT slug, raw_text, filename, 1024 FROM embeddings_1024;
|
||||||
|
|
||||||
|
INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size)
|
||||||
|
SELECT slug, raw_text, filename, 1536 FROM embeddings_1536;
|
||||||
|
|
||||||
|
INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size)
|
||||||
|
SELECT slug, raw_text, filename, 2048 FROM embeddings_2048;
|
||||||
|
|
||||||
|
INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size)
|
||||||
|
SELECT slug, raw_text, filename, 3072 FROM embeddings_3072;
|
||||||
|
|
||||||
|
INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size)
|
||||||
|
SELECT slug, raw_text, filename, 4096 FROM embeddings_4096;
|
||||||
|
|
||||||
|
INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size)
|
||||||
|
SELECT slug, raw_text, filename, 5120 FROM embeddings_5120;
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gf-lt/models"
|
"gf-lt/models"
|
||||||
|
"sort"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
@@ -11,7 +12,7 @@ import (
|
|||||||
|
|
||||||
type VectorRepo interface {
|
type VectorRepo interface {
|
||||||
WriteVector(*models.VectorRow) error
|
WriteVector(*models.VectorRow) error
|
||||||
SearchClosest(q []float32) ([]models.VectorRow, error)
|
SearchClosest(q []float32, limit int) ([]models.VectorRow, error)
|
||||||
ListFiles() ([]string, error)
|
ListFiles() ([]string, error)
|
||||||
RemoveEmbByFileName(filename string) error
|
RemoveEmbByFileName(filename string) error
|
||||||
DB() *sqlx.DB
|
DB() *sqlx.DB
|
||||||
@@ -79,7 +80,7 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
func (p ProviderSQL) SearchClosest(q []float32, limit int) ([]models.VectorRow, error) {
|
||||||
tableName, err := fetchTableName(q)
|
tableName, err := fetchTableName(q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -94,7 +95,7 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
|||||||
vector models.VectorRow
|
vector models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
var topResults []SearchResult
|
var allResults []SearchResult
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
embeddingsBlob []byte
|
embeddingsBlob []byte
|
||||||
@@ -119,28 +120,19 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
|||||||
},
|
},
|
||||||
distance: distance,
|
distance: distance,
|
||||||
}
|
}
|
||||||
|
allResults = append(allResults, result)
|
||||||
// Add to top results and maintain only top results
|
}
|
||||||
topResults = append(topResults, result)
|
// Sort by distance
|
||||||
|
sort.Slice(allResults, func(i, j int) bool {
|
||||||
// Sort and keep only top results
|
return allResults[i].distance < allResults[j].distance
|
||||||
// We'll keep the top 3 closest vectors
|
})
|
||||||
if len(topResults) > 3 {
|
// Truncate to limit
|
||||||
// Simple sort and truncate to maintain only 3 best matches
|
if len(allResults) > limit {
|
||||||
for i := 0; i < len(topResults); i++ {
|
allResults = allResults[:limit]
|
||||||
for j := i + 1; j < len(topResults); j++ {
|
|
||||||
if topResults[i].distance > topResults[j].distance {
|
|
||||||
topResults[i], topResults[j] = topResults[j], topResults[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
topResults = topResults[:3]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert back to VectorRow slice
|
// Convert back to VectorRow slice
|
||||||
results := make([]models.VectorRow, len(topResults))
|
results := make([]models.VectorRow, len(allResults))
|
||||||
for i, result := range topResults {
|
for i, result := range allResults {
|
||||||
result.vector.Distance = result.distance
|
result.vector.Distance = result.distance
|
||||||
results[i] = result.vector
|
results[i] = result.vector
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user