Enha (rag): async writes
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sugarme/tokenizer"
|
"github.com/sugarme/tokenizer"
|
||||||
"github.com/sugarme/tokenizer/pretrained"
|
"github.com/sugarme/tokenizer/pretrained"
|
||||||
@@ -33,8 +34,10 @@ type APIEmbedder struct {
|
|||||||
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
|
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
|
||||||
return &APIEmbedder{
|
return &APIEmbedder{
|
||||||
logger: l,
|
logger: l,
|
||||||
client: &http.Client{},
|
client: &http.Client{
|
||||||
cfg: cfg,
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
446
rag/rag.go
446
rag/rag.go
@@ -1,6 +1,7 @@
|
|||||||
package rag
|
package rag
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gf-lt/config"
|
"gf-lt/config"
|
||||||
@@ -9,6 +10,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"path"
|
"path"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -17,9 +19,14 @@ import (
|
|||||||
"github.com/neurosnap/sentences/english"
|
"github.com/neurosnap/sentences/english"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// batchTimeout is the maximum time allowed for embedding a single batch
|
||||||
|
batchTimeout = 2 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Status messages for TUI integration
|
// Status messages for TUI integration
|
||||||
LongJobStatusCh = make(chan string, 10) // Increased buffer size to prevent blocking
|
LongJobStatusCh = make(chan string, 100) // Increased buffer size for parallel batch updates
|
||||||
FinishedRAGStatus = "finished loading RAG file; press Enter"
|
FinishedRAGStatus = "finished loading RAG file; press Enter"
|
||||||
LoadedFileRAGStatus = "loaded file"
|
LoadedFileRAGStatus = "loaded file"
|
||||||
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
||||||
@@ -31,12 +38,38 @@ type RAG struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
embedder Embedder
|
embedder Embedder
|
||||||
storage *VectorStorage
|
storage *VectorStorage
|
||||||
mu sync.Mutex
|
mu sync.RWMutex
|
||||||
|
idleMu sync.Mutex
|
||||||
fallbackMsg string
|
fallbackMsg string
|
||||||
idleTimer *time.Timer
|
idleTimer *time.Timer
|
||||||
idleTimeout time.Duration
|
idleTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// batchTask represents a single batch to be embedded
|
||||||
|
type batchTask struct {
|
||||||
|
batchIndex int
|
||||||
|
paragraphs []string
|
||||||
|
filename string
|
||||||
|
totalBatches int
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchResult represents the result of embedding a batch
|
||||||
|
type batchResult struct {
|
||||||
|
batchIndex int
|
||||||
|
embeddings [][]float32
|
||||||
|
paragraphs []string
|
||||||
|
filename string
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendStatusNonBlocking sends a status message without blocking
|
||||||
|
func (r *RAG) sendStatusNonBlocking(status string) {
|
||||||
|
select {
|
||||||
|
case LongJobStatusCh <- status:
|
||||||
|
default:
|
||||||
|
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
|
||||||
var embedder Embedder
|
var embedder Embedder
|
||||||
var fallbackMsg string
|
var fallbackMsg string
|
||||||
@@ -142,18 +175,22 @@ func sanitizeFTSQuery(query string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) LoadRAG(fpath string) error {
|
func (r *RAG) LoadRAG(fpath string) error {
|
||||||
|
return r.LoadRAGWithContext(context.Background(), fpath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
fileText, err := ExtractText(fpath)
|
fileText, err := ExtractText(fpath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r.logger.Debug("rag: loaded file", "fp", fpath)
|
r.logger.Debug("rag: loaded file", "fp", fpath)
|
||||||
select {
|
|
||||||
case LongJobStatusCh <- LoadedFileRAGStatus:
|
// Send initial status (non-blocking with retry)
|
||||||
default:
|
r.sendStatusNonBlocking(LoadedFileRAGStatus)
|
||||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
|
|
||||||
}
|
|
||||||
tokenizer, err := english.NewSentenceTokenizer(nil)
|
tokenizer, err := english.NewSentenceTokenizer(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -163,6 +200,7 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
for i, s := range sentences {
|
for i, s := range sentences {
|
||||||
sents[i] = s.Text
|
sents[i] = s.Text
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create chunks with overlap
|
// Create chunks with overlap
|
||||||
paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords)
|
paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords)
|
||||||
// Adjust batch size if needed
|
// Adjust batch size if needed
|
||||||
@@ -172,76 +210,332 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
if len(paragraphs) == 0 {
|
if len(paragraphs) == 0 {
|
||||||
return errors.New("no valid paragraphs found in file")
|
return errors.New("no valid paragraphs found in file")
|
||||||
}
|
}
|
||||||
// Process paragraphs in batches synchronously
|
|
||||||
batchCount := 0
|
totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize
|
||||||
for i := 0; i < len(paragraphs); i += r.cfg.RAGBatchSize {
|
r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize)
|
||||||
end := i + r.cfg.RAGBatchSize
|
|
||||||
if end > len(paragraphs) {
|
// Determine concurrency level
|
||||||
end = len(paragraphs)
|
concurrency := runtime.NumCPU()
|
||||||
}
|
if concurrency > totalBatches {
|
||||||
batch := paragraphs[i:end]
|
concurrency = totalBatches
|
||||||
batchCount++
|
}
|
||||||
// Filter empty paragraphs
|
if concurrency < 1 {
|
||||||
nonEmptyBatch := make([]string, 0, len(batch))
|
concurrency = 1
|
||||||
for _, p := range batch {
|
}
|
||||||
if strings.TrimSpace(p) != "" {
|
// If using ONNX embedder, limit concurrency to 1 due to mutex serialization
|
||||||
nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
|
isONNX := false
|
||||||
|
if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX {
|
||||||
|
concurrency = 1
|
||||||
|
}
|
||||||
|
embedderType := "API"
|
||||||
|
if isONNX {
|
||||||
|
embedderType = "ONNX"
|
||||||
|
}
|
||||||
|
r.logger.Debug("parallel embedding setup",
|
||||||
|
"total_batches", totalBatches,
|
||||||
|
"concurrency", concurrency,
|
||||||
|
"embedder", embedderType,
|
||||||
|
"batch_size", r.cfg.RAGBatchSize)
|
||||||
|
|
||||||
|
// Create context with timeout (30 minutes) and cancellation for error handling
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Channels for task distribution and results
|
||||||
|
taskCh := make(chan batchTask, totalBatches)
|
||||||
|
resultCh := make(chan batchResult, totalBatches)
|
||||||
|
errorCh := make(chan error, totalBatches)
|
||||||
|
|
||||||
|
// Start worker goroutines
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for w := 0; w < concurrency; w++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go r.embeddingWorker(ctx, w, taskCh, resultCh, errorCh, &wg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close task channel after all tasks are sent (by separate goroutine)
|
||||||
|
go func() {
|
||||||
|
// Ensure task channel is closed when this goroutine exits
|
||||||
|
defer close(taskCh)
|
||||||
|
r.logger.Debug("task distributor started", "total_batches", totalBatches)
|
||||||
|
|
||||||
|
for i := 0; i < totalBatches; i++ {
|
||||||
|
start := i * r.cfg.RAGBatchSize
|
||||||
|
end := start + r.cfg.RAGBatchSize
|
||||||
|
if end > len(paragraphs) {
|
||||||
|
end = len(paragraphs)
|
||||||
|
}
|
||||||
|
batch := paragraphs[start:end]
|
||||||
|
|
||||||
|
// Filter empty paragraphs
|
||||||
|
nonEmptyBatch := make([]string, 0, len(batch))
|
||||||
|
for _, p := range batch {
|
||||||
|
if strings.TrimSpace(p) != "" {
|
||||||
|
nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
task := batchTask{
|
||||||
|
batchIndex: i,
|
||||||
|
paragraphs: nonEmptyBatch,
|
||||||
|
filename: path.Base(fpath),
|
||||||
|
totalBatches: totalBatches,
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case taskCh <- task:
|
||||||
|
r.logger.Debug("task distributor sent batch", "batch", i, "paragraphs", len(nonEmptyBatch))
|
||||||
|
case <-ctx.Done():
|
||||||
|
r.logger.Debug("task distributor cancelled", "batches_sent", i+1, "total_batches", totalBatches)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(nonEmptyBatch) == 0 {
|
r.logger.Debug("task distributor finished", "batches_sent", totalBatches)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for workers to finish and close result channel
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(resultCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Process results in order and write to database
|
||||||
|
nextExpectedBatch := 0
|
||||||
|
resultsBuffer := make(map[int]batchResult)
|
||||||
|
filename := path.Base(fpath)
|
||||||
|
batchesProcessed := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
|
||||||
|
case err := <-errorCh:
|
||||||
|
// First error from any worker, cancel everything
|
||||||
|
cancel()
|
||||||
|
r.logger.Error("embedding worker failed", "error", err)
|
||||||
|
r.sendStatusNonBlocking(ErrRAGStatus)
|
||||||
|
return fmt.Errorf("embedding failed: %w", err)
|
||||||
|
|
||||||
|
case result, ok := <-resultCh:
|
||||||
|
if !ok {
|
||||||
|
// All results processed
|
||||||
|
resultCh = nil
|
||||||
|
r.logger.Debug("result channel closed", "batches_processed", batchesProcessed, "total_batches", totalBatches)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store result in buffer
|
||||||
|
resultsBuffer[result.batchIndex] = result
|
||||||
|
|
||||||
|
// Process buffered results in order
|
||||||
|
for {
|
||||||
|
if res, exists := resultsBuffer[nextExpectedBatch]; exists {
|
||||||
|
// Write this batch to database
|
||||||
|
if err := r.writeBatchToStorage(ctx, res, filename); err != nil {
|
||||||
|
cancel()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
batchesProcessed++
|
||||||
|
// Send progress update
|
||||||
|
statusMsg := fmt.Sprintf("processed batch %d/%d", batchesProcessed, totalBatches)
|
||||||
|
r.sendStatusNonBlocking(statusMsg)
|
||||||
|
|
||||||
|
delete(resultsBuffer, nextExpectedBatch)
|
||||||
|
nextExpectedBatch++
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
// No channels ready, check for deadlock conditions
|
||||||
|
if resultCh == nil && nextExpectedBatch < totalBatches {
|
||||||
|
// Missing batch results after result channel closed
|
||||||
|
r.logger.Error("missing batch results",
|
||||||
|
"expected", totalBatches,
|
||||||
|
"received", nextExpectedBatch,
|
||||||
|
"missing", totalBatches-nextExpectedBatch)
|
||||||
|
|
||||||
|
// Wait a short time for any delayed errors, then cancel
|
||||||
|
select {
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
cancel()
|
||||||
|
return fmt.Errorf("missing batch results: expected %d, got %d", totalBatches, nextExpectedBatch)
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-errorCh:
|
||||||
|
cancel()
|
||||||
|
r.logger.Error("embedding worker failed after result channel closed", "error", err)
|
||||||
|
r.sendStatusNonBlocking(ErrRAGStatus)
|
||||||
|
return fmt.Errorf("embedding failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If we reach here, no deadlock yet, just busy loop prevention
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we're done
|
||||||
|
if resultCh == nil && nextExpectedBatch >= totalBatches {
|
||||||
|
r.logger.Debug("all batches processed successfully", "total", totalBatches)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Debug("finished writing vectors", "batches", batchesProcessed)
|
||||||
|
r.resetIdleTimer()
|
||||||
|
r.sendStatusNonBlocking(FinishedRAGStatus)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// embeddingWorker processes batch embedding tasks
|
||||||
|
func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan batchTask, resultCh chan<- batchResult, errorCh chan<- error, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
r.logger.Debug("embedding worker started", "worker", workerID)
|
||||||
|
|
||||||
|
// Panic recovery to ensure worker doesn't crash silently
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
r.logger.Error("embedding worker panicked", "worker", workerID, "panic", rec)
|
||||||
|
// Try to send error, but don't block if channel is full
|
||||||
|
select {
|
||||||
|
case errorCh <- fmt.Errorf("worker %d panicked: %v", workerID, rec):
|
||||||
|
default:
|
||||||
|
r.logger.Warn("error channel full, dropping panic error", "worker", workerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for task := range taskCh {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
r.logger.Debug("embedding worker cancelled", "worker", workerID)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
r.logger.Debug("worker processing batch", "worker", workerID, "batch", task.batchIndex, "paragraphs", len(task.paragraphs), "total_batches", task.totalBatches)
|
||||||
|
|
||||||
|
// Skip empty batches
|
||||||
|
if len(task.paragraphs) == 0 {
|
||||||
|
select {
|
||||||
|
case resultCh <- batchResult{
|
||||||
|
batchIndex: task.batchIndex,
|
||||||
|
embeddings: nil,
|
||||||
|
paragraphs: nil,
|
||||||
|
filename: task.filename,
|
||||||
|
}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
r.logger.Debug("embedding worker cancelled while sending empty batch", "worker", workerID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Embed the batch
|
|
||||||
embeddings, err := r.embedder.EmbedSlice(nonEmptyBatch)
|
// Embed with retry for API embedder
|
||||||
|
embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to embed batch", "error", err, "batch", batchCount)
|
// Try to send error, but don't block indefinitely
|
||||||
select {
|
select {
|
||||||
case LongJobStatusCh <- ErrRAGStatus:
|
case errorCh <- fmt.Errorf("worker %d batch %d: %w", workerID, task.batchIndex, err):
|
||||||
default:
|
case <-ctx.Done():
|
||||||
r.logger.Warn("LongJobStatusCh channel full, dropping message")
|
r.logger.Debug("embedding worker cancelled while sending error", "worker", workerID)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("failed to embed batch %d: %w", batchCount, err)
|
return
|
||||||
}
|
}
|
||||||
if len(embeddings) != len(nonEmptyBatch) {
|
|
||||||
err := errors.New("embedding count mismatch")
|
// Send result with context awareness
|
||||||
r.logger.Error("embedding mismatch", "expected", len(nonEmptyBatch), "got", len(embeddings))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Write vectors to storage
|
|
||||||
filename := path.Base(fpath)
|
|
||||||
for j, text := range nonEmptyBatch {
|
|
||||||
vector := models.VectorRow{
|
|
||||||
Embeddings: embeddings[j],
|
|
||||||
RawText: text,
|
|
||||||
Slug: fmt.Sprintf("%s_%d_%d", filename, batchCount, j),
|
|
||||||
FileName: filename,
|
|
||||||
}
|
|
||||||
if err := r.storage.WriteVector(&vector); err != nil {
|
|
||||||
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
|
|
||||||
select {
|
|
||||||
case LongJobStatusCh <- ErrRAGStatus:
|
|
||||||
default:
|
|
||||||
r.logger.Warn("LongJobStatusCh channel full, dropping message")
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to write vector: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.logger.Debug("wrote batch to db", "batch", batchCount, "size", len(nonEmptyBatch))
|
|
||||||
// Send progress status
|
|
||||||
statusMsg := fmt.Sprintf("processed batch %d/%d", batchCount, (len(paragraphs)+r.cfg.RAGBatchSize-1)/r.cfg.RAGBatchSize)
|
|
||||||
select {
|
select {
|
||||||
case LongJobStatusCh <- statusMsg:
|
case resultCh <- batchResult{
|
||||||
default:
|
batchIndex: task.batchIndex,
|
||||||
r.logger.Warn("LongJobStatusCh channel full, dropping message")
|
embeddings: embeddings,
|
||||||
|
paragraphs: task.paragraphs,
|
||||||
|
filename: task.filename,
|
||||||
|
}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
r.logger.Debug("embedding worker cancelled while sending result", "worker", workerID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.logger.Debug("worker completed batch", "worker", workerID, "batch", task.batchIndex, "embeddings", len(embeddings))
|
||||||
|
}
|
||||||
|
r.logger.Debug("embedding worker finished", "worker", workerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// embedWithRetry attempts embedding with exponential backoff for API embedder
|
||||||
|
func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) {
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
// Exponential backoff
|
||||||
|
backoff := time.Duration(attempt*attempt) * time.Second
|
||||||
|
if backoff > 10*time.Second {
|
||||||
|
backoff = 10 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(backoff):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries)
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings, err := r.embedder.EmbedSlice(paragraphs)
|
||||||
|
if err == nil {
|
||||||
|
// Validate embedding count
|
||||||
|
if len(embeddings) != len(paragraphs) {
|
||||||
|
return nil, fmt.Errorf("embedding count mismatch: expected %d, got %d", len(paragraphs), len(embeddings))
|
||||||
|
}
|
||||||
|
return embeddings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lastErr = err
|
||||||
|
// Only retry for API embedder errors (network/timeout)
|
||||||
|
// For ONNX embedder, fail fast
|
||||||
|
if _, isAPI := r.embedder.(*APIEmbedder); !isAPI {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.logger.Debug("finished writing vectors", "batches", batchCount)
|
|
||||||
r.resetIdleTimer()
|
return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr)
|
||||||
select {
|
}
|
||||||
case LongJobStatusCh <- FinishedRAGStatus:
|
|
||||||
default:
|
// writeBatchToStorage writes a single batch of vectors to the database
|
||||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
|
func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filename string) error {
|
||||||
|
if len(result.embeddings) == 0 {
|
||||||
|
// Empty batch, skip
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check context before starting
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build all vectors for batch write
|
||||||
|
vectors := make([]*models.VectorRow, 0, len(result.paragraphs))
|
||||||
|
for j, text := range result.paragraphs {
|
||||||
|
vectors = append(vectors, &models.VectorRow{
|
||||||
|
Embeddings: result.embeddings[j],
|
||||||
|
RawText: text,
|
||||||
|
Slug: fmt.Sprintf("%s_%d_%d", filename, result.batchIndex+1, j),
|
||||||
|
FileName: filename,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write all vectors in a single transaction
|
||||||
|
if err := r.storage.WriteVectors(vectors); err != nil {
|
||||||
|
r.logger.Error("failed to write vectors batch to DB", "error", err, "batch", result.batchIndex+1, "size", len(vectors))
|
||||||
|
r.sendStatusNonBlocking(ErrRAGStatus)
|
||||||
|
return fmt.Errorf("failed to write vectors batch: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,22 +544,26 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
|
|||||||
return r.embedder.Embed(line)
|
return r.embedder.Embed(line)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) SearchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
|
func (r *RAG) searchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
|
||||||
r.resetIdleTimer()
|
r.resetIdleTimer()
|
||||||
return r.storage.SearchClosest(emb.Embedding, limit)
|
return r.storage.SearchClosest(emb.Embedding, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) SearchKeyword(query string, limit int) ([]models.VectorRow, error) {
|
func (r *RAG) searchKeyword(query string, limit int) ([]models.VectorRow, error) {
|
||||||
r.resetIdleTimer()
|
r.resetIdleTimer()
|
||||||
sanitized := sanitizeFTSQuery(query)
|
sanitized := sanitizeFTSQuery(query)
|
||||||
return r.storage.SearchKeyword(sanitized, limit)
|
return r.storage.SearchKeyword(sanitized, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) ListLoaded() ([]string, error) {
|
func (r *RAG) ListLoaded() ([]string, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
return r.storage.ListFiles()
|
return r.storage.ListFiles()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) RemoveFile(filename string) error {
|
func (r *RAG) RemoveFile(filename string) error {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
r.resetIdleTimer()
|
r.resetIdleTimer()
|
||||||
return r.storage.RemoveEmbByFileName(filename)
|
return r.storage.RemoveEmbByFileName(filename)
|
||||||
}
|
}
|
||||||
@@ -454,6 +752,9 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
|
func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
r.resetIdleTimer()
|
||||||
if len(results) == 0 {
|
if len(results) == 0 {
|
||||||
return "No relevant information found in the vector database.", nil
|
return "No relevant information found in the vector database.", nil
|
||||||
}
|
}
|
||||||
@@ -482,7 +783,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
|
|||||||
Embedding: emb,
|
Embedding: emb,
|
||||||
Index: 0,
|
Index: 0,
|
||||||
}
|
}
|
||||||
topResults, err := r.SearchEmb(embResp, 1)
|
topResults, err := r.searchEmb(embResp, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to search for synthesis context", "error", err)
|
r.logger.Error("failed to search for synthesis context", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -509,6 +810,9 @@ func truncateString(s string, maxLen int) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
r.resetIdleTimer()
|
||||||
refined := r.RefineQuery(query)
|
refined := r.RefineQuery(query)
|
||||||
variations := r.GenerateQueryVariations(refined)
|
variations := r.GenerateQueryVariations(refined)
|
||||||
|
|
||||||
@@ -525,7 +829,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
|||||||
Embedding: emb,
|
Embedding: emb,
|
||||||
Index: 0,
|
Index: 0,
|
||||||
}
|
}
|
||||||
results, err := r.SearchEmb(embResp, limit*2) // Get more candidates
|
results, err := r.searchEmb(embResp, limit*2) // Get more candidates
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to search embeddings", "error", err, "query", q)
|
r.logger.Error("failed to search embeddings", "error", err, "query", q)
|
||||||
continue
|
continue
|
||||||
@@ -543,7 +847,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Perform keyword search
|
// Perform keyword search
|
||||||
kwResults, err := r.SearchKeyword(refined, limit*2)
|
kwResults, err := r.searchKeyword(refined, limit*2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Warn("keyword search failed, using only embeddings", "error", err)
|
r.logger.Warn("keyword search failed, using only embeddings", "error", err)
|
||||||
kwResults = nil
|
kwResults = nil
|
||||||
@@ -621,6 +925,8 @@ func GetInstance() *RAG {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) resetIdleTimer() {
|
func (r *RAG) resetIdleTimer() {
|
||||||
|
r.idleMu.Lock()
|
||||||
|
defer r.idleMu.Unlock()
|
||||||
if r.idleTimer != nil {
|
if r.idleTimer != nil {
|
||||||
r.idleTimer.Stop()
|
r.idleTimer.Stop()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -102,6 +102,92 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteVectors stores multiple embedding vectors in a single transaction
|
||||||
|
func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error {
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// SQLite has limit of 999 parameters per statement, each row uses 4 parameters
|
||||||
|
const maxBatchSize = 200 // 200 * 4 = 800 < 999
|
||||||
|
if len(rows) > maxBatchSize {
|
||||||
|
// Process in chunks
|
||||||
|
for i := 0; i < len(rows); i += maxBatchSize {
|
||||||
|
end := i + maxBatchSize
|
||||||
|
if end > len(rows) {
|
||||||
|
end = len(rows)
|
||||||
|
}
|
||||||
|
if err := vs.WriteVectors(rows[i:end]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// All rows should have same embedding size (same model)
|
||||||
|
firstSize := len(rows[0].Embeddings)
|
||||||
|
for i, row := range rows {
|
||||||
|
if len(row.Embeddings) != firstSize {
|
||||||
|
return fmt.Errorf("embedding size mismatch: row %d has size %d, expected %d", i, len(row.Embeddings), firstSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tableName, err := vs.getTableName(rows[0].Embeddings)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start transaction
|
||||||
|
tx, err := vs.sqlxDB.Beginx()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Build batch insert for embeddings table
|
||||||
|
embeddingPlaceholders := make([]string, 0, len(rows))
|
||||||
|
embeddingArgs := make([]any, 0, len(rows)*4)
|
||||||
|
for _, row := range rows {
|
||||||
|
embeddingPlaceholders = append(embeddingPlaceholders, "(?, ?, ?, ?)")
|
||||||
|
embeddingArgs = append(embeddingArgs, SerializeVector(row.Embeddings), row.Slug, row.RawText, row.FileName)
|
||||||
|
}
|
||||||
|
embeddingQuery := fmt.Sprintf(
|
||||||
|
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES %s",
|
||||||
|
tableName,
|
||||||
|
strings.Join(embeddingPlaceholders, ", "),
|
||||||
|
)
|
||||||
|
if _, err := tx.Exec(embeddingQuery, embeddingArgs...); err != nil {
|
||||||
|
vs.logger.Error("failed to write vectors batch", "error", err, "batch_size", len(rows))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build batch insert for FTS table
|
||||||
|
ftsPlaceholders := make([]string, 0, len(rows))
|
||||||
|
ftsArgs := make([]any, 0, len(rows)*4)
|
||||||
|
embeddingSize := len(rows[0].Embeddings)
|
||||||
|
for _, row := range rows {
|
||||||
|
ftsPlaceholders = append(ftsPlaceholders, "(?, ?, ?, ?)")
|
||||||
|
ftsArgs = append(ftsArgs, row.Slug, row.RawText, row.FileName, embeddingSize)
|
||||||
|
}
|
||||||
|
ftsQuery := fmt.Sprintf(
|
||||||
|
"INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES %s",
|
||||||
|
strings.Join(ftsPlaceholders, ", "),
|
||||||
|
)
|
||||||
|
if _, err := tx.Exec(ftsQuery, ftsArgs...); err != nil {
|
||||||
|
vs.logger.Error("failed to write FTS batch", "error", err, "batch_size", len(rows))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
vs.logger.Error("failed to commit transaction", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
vs.logger.Debug("wrote vectors batch", "batch_size", len(rows))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// getTableName determines which table to use based on embedding size
|
// getTableName determines which table to use based on embedding size
|
||||||
func (vs *VectorStorage) getTableName(emb []float32) (string, error) {
|
func (vs *VectorStorage) getTableName(emb []float32) (string, error) {
|
||||||
size := len(emb)
|
size := len(emb)
|
||||||
|
|||||||
@@ -102,6 +102,22 @@ func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo {
|
|||||||
logger.Error("failed to open db connection", "error", err)
|
logger.Error("failed to open db connection", "error", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
// Enable WAL mode for better concurrency and performance
|
||||||
|
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
|
||||||
|
logger.Warn("failed to enable WAL mode", "error", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil {
|
||||||
|
logger.Warn("failed to set synchronous mode", "error", err)
|
||||||
|
}
|
||||||
|
// Increase cache size for better performance
|
||||||
|
if _, err := db.Exec("PRAGMA cache_size = -2000;"); err != nil {
|
||||||
|
logger.Warn("failed to set cache size", "error", err)
|
||||||
|
}
|
||||||
|
// Log actual journal mode for debugging
|
||||||
|
var journalMode string
|
||||||
|
if err := db.QueryRow("PRAGMA journal_mode;").Scan(&journalMode); err == nil {
|
||||||
|
logger.Debug("SQLite journal mode", "mode", journalMode)
|
||||||
|
}
|
||||||
p := ProviderSQL{db: db, logger: logger}
|
p := ProviderSQL{db: db, logger: logger}
|
||||||
if err := p.Migrate(); err != nil {
|
if err := p.Migrate(); err != nil {
|
||||||
logger.Error("migration failed, app cannot start", "error", err)
|
logger.Error("migration failed, app cannot start", "error", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user