Refactor: rag to sep package
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"elefant/models"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"unsafe"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces"
|
||||
@@ -12,7 +11,7 @@ import (
|
||||
|
||||
type VectorRepo interface {
|
||||
WriteVector(*models.VectorRow) error
|
||||
SearchClosest(q []float32) (*models.VectorRow, error)
|
||||
SearchClosest(q []float32) ([]models.VectorRow, error)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -79,7 +78,11 @@ func decodeUnsafe(bs []byte) []float32 {
|
||||
return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4)
|
||||
}
|
||||
|
||||
func (p ProviderSQL) SearchClosest(q []float32) (*models.VectorRow, error) {
|
||||
func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
||||
tableName, err := fetchTableName(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt, _, err := p.s3Conn.Prepare(
|
||||
fmt.Sprintf(`SELECT
|
||||
id,
|
||||
@@ -91,35 +94,35 @@ func (p ProviderSQL) SearchClosest(q []float32) (*models.VectorRow, error) {
|
||||
WHERE embedding MATCH ?
|
||||
ORDER BY distance
|
||||
LIMIT 4
|
||||
`, vecTableName))
|
||||
`, tableName))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
query, err := sqlite_vec.SerializeFloat32(q[:])
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
if err := stmt.BindBlob(1, query); err != nil {
|
||||
p.logger.Error("failed to bind", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
resp := make([]models.VectorRow, 4)
|
||||
i := 0
|
||||
resp := []models.VectorRow{}
|
||||
for stmt.Step() {
|
||||
resp[i].ID = uint32(stmt.ColumnInt64(0))
|
||||
resp[i].Distance = float32(stmt.ColumnFloat(1))
|
||||
res := models.VectorRow{}
|
||||
res.ID = uint32(stmt.ColumnInt64(0))
|
||||
res.Distance = float32(stmt.ColumnFloat(1))
|
||||
emb := stmt.ColumnRawText(2)
|
||||
resp[i].Embeddings = decodeUnsafe(emb)
|
||||
resp[i].Slug = stmt.ColumnText(3)
|
||||
resp[i].RawText = stmt.ColumnText(4)
|
||||
i++
|
||||
res.Embeddings = decodeUnsafe(emb)
|
||||
res.Slug = stmt.ColumnText(3)
|
||||
res.RawText = stmt.ColumnText(4)
|
||||
resp = append(resp, res)
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user