Feat (rag): hybrid search attempt
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"gf-lt/models"
|
||||
"sort"
|
||||
"unsafe"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
@@ -11,7 +12,7 @@ import (
|
||||
|
||||
type VectorRepo interface {
|
||||
WriteVector(*models.VectorRow) error
|
||||
SearchClosest(q []float32) ([]models.VectorRow, error)
|
||||
SearchClosest(q []float32, limit int) ([]models.VectorRow, error)
|
||||
ListFiles() ([]string, error)
|
||||
RemoveEmbByFileName(filename string) error
|
||||
DB() *sqlx.DB
|
||||
@@ -79,7 +80,7 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
||||
func (p ProviderSQL) SearchClosest(q []float32, limit int) ([]models.VectorRow, error) {
|
||||
tableName, err := fetchTableName(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -94,7 +95,7 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
||||
vector models.VectorRow
|
||||
distance float32
|
||||
}
|
||||
var topResults []SearchResult
|
||||
var allResults []SearchResult
|
||||
for rows.Next() {
|
||||
var (
|
||||
embeddingsBlob []byte
|
||||
@@ -119,28 +120,19 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
||||
},
|
||||
distance: distance,
|
||||
}
|
||||
|
||||
// Add to top results and maintain only top results
|
||||
topResults = append(topResults, result)
|
||||
|
||||
// Sort and keep only top results
|
||||
// We'll keep the top 3 closest vectors
|
||||
if len(topResults) > 3 {
|
||||
// Simple sort and truncate to maintain only 3 best matches
|
||||
for i := 0; i < len(topResults); i++ {
|
||||
for j := i + 1; j < len(topResults); j++ {
|
||||
if topResults[i].distance > topResults[j].distance {
|
||||
topResults[i], topResults[j] = topResults[j], topResults[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
topResults = topResults[:3]
|
||||
}
|
||||
allResults = append(allResults, result)
|
||||
}
|
||||
// Sort by distance
|
||||
sort.Slice(allResults, func(i, j int) bool {
|
||||
return allResults[i].distance < allResults[j].distance
|
||||
})
|
||||
// Truncate to limit
|
||||
if len(allResults) > limit {
|
||||
allResults = allResults[:limit]
|
||||
}
|
||||
|
||||
// Convert back to VectorRow slice
|
||||
results := make([]models.VectorRow, len(topResults))
|
||||
for i, result := range topResults {
|
||||
results := make([]models.VectorRow, len(allResults))
|
||||
for i, result := range allResults {
|
||||
result.vector.Distance = result.distance
|
||||
results[i] = result.vector
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user