Feat: RAG file loading status textview
This commit is contained in:
52
rag/main.go
52
rag/main.go
@@ -18,6 +18,14 @@ import (
|
||||
"github.com/neurosnap/sentences/english"
|
||||
)
|
||||
|
||||
var (
	// LongJobStatusCh carries human-readable progress messages emitted while
	// a RAG file is being loaded and written to the vector store. It is
	// buffered with capacity 1, so a send blocks while one message is still
	// pending and nothing is receiving.
	// NOTE(review): the receiving side is not visible here — presumably the
	// UI status text view drains it; confirm against the caller.
	LongJobStatusCh = make(chan string, 1)

	// messages

	// FinishedRAGStatus is sent once all vectors have been written.
	FinishedRAGStatus = "finished loading RAG file; press Enter"
	// LoadedFileRAGStatus is sent right after the source file has been read.
	LoadedFileRAGStatus = "loaded file"
	// ErrRAGStatus is sent when writing a vector to the store fails.
	ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
)
|
||||
|
||||
type RAG struct {
|
||||
logger *slog.Logger
|
||||
store storage.FullRepo
|
||||
@@ -42,6 +50,7 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
return err
|
||||
}
|
||||
r.logger.Info("rag: loaded file", "fp", fpath)
|
||||
LongJobStatusCh <- LoadedFileRAGStatus
|
||||
fileText := string(data)
|
||||
tokenizer, err := english.NewSentenceTokenizer(nil)
|
||||
if err != nil {
|
||||
@@ -49,7 +58,6 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
}
|
||||
sentences := tokenizer.Tokenize(fileText)
|
||||
sents := make([]string, len(sentences))
|
||||
r.logger.Info("rag: sentences", "#", len(sents))
|
||||
for i, s := range sentences {
|
||||
sents[i] = s.Text
|
||||
}
|
||||
@@ -60,16 +68,14 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
batchSize = 100
|
||||
maxChSize = 1000
|
||||
//
|
||||
// psize = 3
|
||||
wordLimit = 80
|
||||
//
|
||||
left = 0
|
||||
right = batchSize
|
||||
batchCh = make(chan map[int][]string, maxChSize)
|
||||
vectorCh = make(chan []models.VectorRow, maxChSize)
|
||||
errCh = make(chan error, 1)
|
||||
doneCh = make(chan bool, 1)
|
||||
lock = new(sync.Mutex)
|
||||
left = 0
|
||||
right = batchSize
|
||||
batchCh = make(chan map[int][]string, maxChSize)
|
||||
vectorCh = make(chan []models.VectorRow, maxChSize)
|
||||
errCh = make(chan error, 1)
|
||||
doneCh = make(chan bool, 1)
|
||||
lock = new(sync.Mutex)
|
||||
)
|
||||
defer close(doneCh)
|
||||
defer close(errCh)
|
||||
@@ -84,13 +90,6 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
par.Reset()
|
||||
}
|
||||
}
|
||||
// for i := 0; i < len(sents); i += psize {
|
||||
// if len(sents) < i+psize {
|
||||
// paragraphs = append(paragraphs, strings.Join(sents[i:], " "))
|
||||
// break
|
||||
// }
|
||||
// paragraphs = append(paragraphs, strings.Join(sents[i:i+psize], " "))
|
||||
// }
|
||||
if len(paragraphs) < batchSize {
|
||||
batchSize = len(paragraphs)
|
||||
}
|
||||
@@ -105,7 +104,9 @@ func (r *RAG) LoadRAG(fpath string) error {
|
||||
left, right = right, right+batchSize
|
||||
ctn++
|
||||
}
|
||||
r.logger.Info("finished batching", "batches#", len(batchCh), "paragraphs", len(paragraphs), "sentences", len(sents))
|
||||
finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", len(batchCh), len(paragraphs), len(sents))
|
||||
r.logger.Info(finishedBatchesMsg)
|
||||
LongJobStatusCh <- finishedBatchesMsg
|
||||
for w := 0; w < workers; w++ {
|
||||
go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
|
||||
}
|
||||
@@ -121,6 +122,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
||||
for _, vector := range batch {
|
||||
if err := r.store.WriteVector(&vector); err != nil {
|
||||
r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug)
|
||||
LongJobStatusCh <- ErrRAGStatus
|
||||
continue // a duplicate is not critical
|
||||
// return err
|
||||
}
|
||||
@@ -128,6 +130,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
||||
r.logger.Info("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
|
||||
if len(vectorCh) == 0 {
|
||||
r.logger.Info("finished writing vectors")
|
||||
LongJobStatusCh <- FinishedRAGStatus
|
||||
defer close(vectorCh)
|
||||
return nil
|
||||
}
|
||||
@@ -150,10 +153,6 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[
|
||||
case linesMap := <-inputCh:
|
||||
for leftI, v := range linesMap {
|
||||
r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename)
|
||||
// if leftI+200 >= limit { // last batch
|
||||
// // doneCh <- true
|
||||
// return
|
||||
// }
|
||||
}
|
||||
lock.Unlock()
|
||||
case err := <-errCh:
|
||||
@@ -162,6 +161,7 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[
|
||||
return
|
||||
}
|
||||
r.logger.Info("to vector batches", "batches#", len(inputCh), "worker#", id)
|
||||
LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,8 +183,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
|
||||
}
|
||||
req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
// nolint
|
||||
// resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
r.logger.Error("failed to embedd line", "err:", err.Error())
|
||||
errCh <- err
|
||||
@@ -194,9 +192,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
|
||||
if resp.StatusCode != 200 {
|
||||
r.logger.Error("non 200 resp", "code", resp.StatusCode)
|
||||
return
|
||||
// err = fmt.Errorf("non 200 resp; url: %s; code %d", r.cfg.EmbedURL, resp.StatusCode)
|
||||
// errCh <- err
|
||||
// return
|
||||
}
|
||||
emb := [][]float32{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
|
||||
@@ -224,7 +219,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
|
||||
}
|
||||
|
||||
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||
// payload, err := json.Marshal(map[string]string{"content": line})
|
||||
lines := []string{line}
|
||||
payload, err := json.Marshal(
|
||||
map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}},
|
||||
@@ -241,7 +235,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||
}
|
||||
req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
// resp, err := req.Post(r.cfg.EmbedURL, "application/json", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
r.logger.Error("failed to embedd line", "err:", err.Error())
|
||||
return nil, err
|
||||
@@ -252,7 +245,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||
r.logger.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
// emb := models.EmbeddingResp{}
|
||||
emb := [][]float32{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
|
||||
r.logger.Error("failed to embedd line", "err:", err.Error())
|
||||
|
||||
Reference in New Issue
Block a user