Fix: RAG updates
This commit is contained in:
@@ -12,17 +12,19 @@ import (
|
||||
// VectorRepo describes the storage operations available for embedding rows
// (models.VectorRow). ProviderSQL defines methods with these signatures.
type VectorRepo interface {
|
||||
// WriteVector stores a single embedding row.
WriteVector(*models.VectorRow) error
|
||||
// SearchClosest returns stored rows matched against the query embedding q.
SearchClosest(q []float32) ([]models.VectorRow, error)
|
||||
// ListFiles returns the file names that have stored rows.
ListFiles() ([]string, error)
|
||||
// RemoveEmbByFileName deletes the rows associated with filename.
RemoveEmbByFileName(filename string) error
|
||||
}
|
||||
|
||||
// Names of the tables that hold embeddings; fetchTableName selects one of
// them based on len(emb).
// NOTE(review): this listing interleaves removed and added diff lines, which
// is why vecTableName384 appears twice — confirm against the actual file.
var (
|
||||
vecTableName = "embeddings"
|
||||
vecTableName384 = "embeddings_384"
|
||||
vecTableName5120 = "embeddings_5120"
|
||||
vecTableName384 = "embeddings_384"
|
||||
)
|
||||
|
||||
func fetchTableName(emb []float32) (string, error) {
|
||||
switch len(emb) {
|
||||
case 5120:
|
||||
return vecTableName, nil
|
||||
return vecTableName5120, nil
|
||||
case 384:
|
||||
return vecTableName384, nil
|
||||
default:
|
||||
@@ -36,7 +38,7 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
|
||||
return err
|
||||
}
|
||||
stmt, _, err := p.s3Conn.Prepare(
|
||||
fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text) VALUES (?, ?, ?)", tableName))
|
||||
fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName))
|
||||
if err != nil {
|
||||
p.logger.Error("failed to prep a stmt", "error", err)
|
||||
return err
|
||||
@@ -66,6 +68,10 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
|
||||
p.logger.Error("failed to bind", "error", err)
|
||||
return err
|
||||
}
|
||||
if err := stmt.BindText(4, row.FileName); err != nil {
|
||||
p.logger.Error("failed to bind", "error", err)
|
||||
return err
|
||||
}
|
||||
err = stmt.Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -87,11 +93,12 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
||||
distance,
|
||||
embedding,
|
||||
slug,
|
||||
raw_text
|
||||
raw_text,
|
||||
filename
|
||||
FROM %s
|
||||
WHERE embedding MATCH ?
|
||||
ORDER BY distance
|
||||
LIMIT 4
|
||||
LIMIT 3
|
||||
`, tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -112,6 +119,7 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
||||
res.Embeddings = decodeUnsafe(emb)
|
||||
res.Slug = stmt.ColumnText(2)
|
||||
res.RawText = stmt.ColumnText(3)
|
||||
res.FileName = stmt.ColumnText(4)
|
||||
resp = append(resp, res)
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
@@ -123,3 +131,33 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p ProviderSQL) ListFiles() ([]string, error) {
|
||||
q := fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384)
|
||||
stmt, _, err := p.s3Conn.Prepare(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
resp := []string{}
|
||||
for stmt.Step() {
|
||||
resp = append(resp, stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
||||
q := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384)
|
||||
stmt, _, err := p.s3Conn.Prepare(q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if err := stmt.BindText(1, filename); err != nil {
|
||||
return err
|
||||
}
|
||||
return stmt.Exec()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user