Feat: RAG file loading status textview

This commit is contained in:
Grail Finder
2025-01-11 17:29:21 +03:00
parent f40d8afe08
commit 85f96aa401
3 changed files with 92 additions and 49 deletions

View File

@@ -18,6 +18,14 @@ import (
"github.com/neurosnap/sentences/english" "github.com/neurosnap/sentences/english"
) )
var (
LongJobStatusCh = make(chan string, 1)
// messages
FinishedRAGStatus = "finished loading RAG file; press Enter"
LoadedFileRAGStatus = "loaded file"
ErrRAGStatus = "some error occured; failed to transfer data to vector db"
)
type RAG struct { type RAG struct {
logger *slog.Logger logger *slog.Logger
store storage.FullRepo store storage.FullRepo
@@ -42,6 +50,7 @@ func (r *RAG) LoadRAG(fpath string) error {
return err return err
} }
r.logger.Info("rag: loaded file", "fp", fpath) r.logger.Info("rag: loaded file", "fp", fpath)
LongJobStatusCh <- LoadedFileRAGStatus
fileText := string(data) fileText := string(data)
tokenizer, err := english.NewSentenceTokenizer(nil) tokenizer, err := english.NewSentenceTokenizer(nil)
if err != nil { if err != nil {
@@ -49,7 +58,6 @@ func (r *RAG) LoadRAG(fpath string) error {
} }
sentences := tokenizer.Tokenize(fileText) sentences := tokenizer.Tokenize(fileText)
sents := make([]string, len(sentences)) sents := make([]string, len(sentences))
r.logger.Info("rag: sentences", "#", len(sents))
for i, s := range sentences { for i, s := range sentences {
sents[i] = s.Text sents[i] = s.Text
} }
@@ -60,9 +68,7 @@ func (r *RAG) LoadRAG(fpath string) error {
batchSize = 100 batchSize = 100
maxChSize = 1000 maxChSize = 1000
// //
// psize = 3
wordLimit = 80 wordLimit = 80
//
left = 0 left = 0
right = batchSize right = batchSize
batchCh = make(chan map[int][]string, maxChSize) batchCh = make(chan map[int][]string, maxChSize)
@@ -84,13 +90,6 @@ func (r *RAG) LoadRAG(fpath string) error {
par.Reset() par.Reset()
} }
} }
// for i := 0; i < len(sents); i += psize {
// if len(sents) < i+psize {
// paragraphs = append(paragraphs, strings.Join(sents[i:], " "))
// break
// }
// paragraphs = append(paragraphs, strings.Join(sents[i:i+psize], " "))
// }
if len(paragraphs) < batchSize { if len(paragraphs) < batchSize {
batchSize = len(paragraphs) batchSize = len(paragraphs)
} }
@@ -105,7 +104,9 @@ func (r *RAG) LoadRAG(fpath string) error {
left, right = right, right+batchSize left, right = right, right+batchSize
ctn++ ctn++
} }
r.logger.Info("finished batching", "batches#", len(batchCh), "paragraphs", len(paragraphs), "sentences", len(sents)) finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", len(batchCh), len(paragraphs), len(sents))
r.logger.Info(finishedBatchesMsg)
LongJobStatusCh <- finishedBatchesMsg
for w := 0; w < workers; w++ { for w := 0; w < workers; w++ {
go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
} }
@@ -121,6 +122,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
for _, vector := range batch { for _, vector := range batch {
if err := r.store.WriteVector(&vector); err != nil { if err := r.store.WriteVector(&vector); err != nil {
r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug) r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug)
LongJobStatusCh <- ErrRAGStatus
continue // a duplicate is not critical continue // a duplicate is not critical
// return err // return err
} }
@@ -128,6 +130,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
r.logger.Info("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) r.logger.Info("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
if len(vectorCh) == 0 { if len(vectorCh) == 0 {
r.logger.Info("finished writing vectors") r.logger.Info("finished writing vectors")
LongJobStatusCh <- FinishedRAGStatus
defer close(vectorCh) defer close(vectorCh)
return nil return nil
} }
@@ -150,10 +153,6 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[
case linesMap := <-inputCh: case linesMap := <-inputCh:
for leftI, v := range linesMap { for leftI, v := range linesMap {
r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename) r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename)
// if leftI+200 >= limit { // last batch
// // doneCh <- true
// return
// }
} }
lock.Unlock() lock.Unlock()
case err := <-errCh: case err := <-errCh:
@@ -162,6 +161,7 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[
return return
} }
r.logger.Info("to vector batches", "batches#", len(inputCh), "worker#", id) r.logger.Info("to vector batches", "batches#", len(inputCh), "worker#", id)
LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
} }
} }
@@ -183,8 +183,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
} }
req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken) req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
// nolint
// resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload))
if err != nil { if err != nil {
r.logger.Error("failed to embedd line", "err:", err.Error()) r.logger.Error("failed to embedd line", "err:", err.Error())
errCh <- err errCh <- err
@@ -194,9 +192,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
r.logger.Error("non 200 resp", "code", resp.StatusCode) r.logger.Error("non 200 resp", "code", resp.StatusCode)
return return
// err = fmt.Errorf("non 200 resp; url: %s; code %d", r.cfg.EmbedURL, resp.StatusCode)
// errCh <- err
// return
} }
emb := [][]float32{} emb := [][]float32{}
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
@@ -224,7 +219,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
} }
func (r *RAG) LineToVector(line string) ([]float32, error) { func (r *RAG) LineToVector(line string) ([]float32, error) {
// payload, err := json.Marshal(map[string]string{"content": line})
lines := []string{line} lines := []string{line}
payload, err := json.Marshal( payload, err := json.Marshal(
map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}},
@@ -241,7 +235,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
} }
req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken) req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
// resp, err := req.Post(r.cfg.EmbedURL, "application/json", bytes.NewReader(payload))
if err != nil { if err != nil {
r.logger.Error("failed to embedd line", "err:", err.Error()) r.logger.Error("failed to embedd line", "err:", err.Error())
return nil, err return nil, err
@@ -252,7 +245,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
r.logger.Error(err.Error()) r.logger.Error(err.Error())
return nil, err return nil, err
} }
// emb := models.EmbeddingResp{}
emb := [][]float32{} emb := [][]float32{}
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
r.logger.Error("failed to embedd line", "err:", err.Error()) r.logger.Error("failed to embedd line", "err:", err.Error())

View File

@@ -1,8 +1,12 @@
package main package main
import ( import (
"fmt"
"os" "os"
"path" "path"
"time"
"elefant/rag"
"github.com/gdamore/tcell/v2" "github.com/gdamore/tcell/v2"
"github.com/rivo/tview" "github.com/rivo/tview"
@@ -85,11 +89,21 @@ func makeChatTable(chatList []string) *tview.Table {
return chatActTable return chatActTable
} }
func makeRAGTable(fileList []string) *tview.Table { // func makeRAGTable(fileList []string) *tview.Table {
func makeRAGTable(fileList []string) *tview.Flex {
actions := []string{"load", "delete"} actions := []string{"load", "delete"}
rows, cols := len(fileList), len(actions)+1 rows, cols := len(fileList), len(actions)+1
fileTable := tview.NewTable(). fileTable := tview.NewTable().
SetBorders(true) SetBorders(true)
longStatusView := tview.NewTextView()
longStatusView.SetText("status text")
longStatusView.SetBorder(true).SetTitle("status")
longStatusView.SetChangedFunc(func() {
app.Draw()
})
ragflex := tview.NewFlex().SetDirection(tview.FlexRow).
AddItem(longStatusView, 0, 10, false).
AddItem(fileTable, 0, 60, true)
for r := 0; r < rows; r++ { for r := 0; r < rows; r++ {
for c := 0; c < cols; c++ { for c := 0; c < cols; c++ {
color := tcell.ColorWhite color := tcell.ColorWhite
@@ -106,6 +120,33 @@ func makeRAGTable(fileList []string) *tview.Table {
} }
} }
} }
errCh := make(chan error, 1)
go func() {
defer pages.RemovePage(RAGPage)
for {
select {
case err := <-errCh:
if err == nil {
logger.Error("somehow got a nil err", "error", err)
continue
}
logger.Error("got an err in rag status", "error", err, "textview", longStatusView)
longStatusView.SetText(fmt.Sprintf("%v", err))
close(errCh)
return
case status := <-rag.LongJobStatusCh:
logger.Info("reading status channel", "status", status)
longStatusView.SetText(status)
// fmt.Fprintln(longStatusView, status)
// app.Sync()
if status == rag.FinishedRAGStatus {
close(errCh)
time.Sleep(2 * time.Second)
return
}
}
}
}()
fileTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { fileTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) {
if key == tcell.KeyEsc || key == tcell.KeyF1 { if key == tcell.KeyEsc || key == tcell.KeyF1 {
pages.RemovePage(RAGPage) pages.RemovePage(RAGPage)
@@ -115,7 +156,7 @@ func makeRAGTable(fileList []string) *tview.Table {
fileTable.SetSelectable(true, true) fileTable.SetSelectable(true, true)
} }
}).SetSelectedFunc(func(row int, column int) { }).SetSelectedFunc(func(row int, column int) {
defer pages.RemovePage(RAGPage) // defer pages.RemovePage(RAGPage)
tc := fileTable.GetCell(row, column) tc := fileTable.GetCell(row, column)
tc.SetTextColor(tcell.ColorRed) tc.SetTextColor(tcell.ColorRed)
fileTable.SetSelectable(false, false) fileTable.SetSelectable(false, false)
@@ -124,14 +165,18 @@ func makeRAGTable(fileList []string) *tview.Table {
switch tc.Text { switch tc.Text {
case "load": case "load":
fpath = path.Join(cfg.RAGDir, fpath) fpath = path.Join(cfg.RAGDir, fpath)
longStatusView.SetText("clicked load")
go func() {
if err := ragger.LoadRAG(fpath); err != nil { if err := ragger.LoadRAG(fpath); err != nil {
logger.Error("failed to embed file", "chat", fpath, "error", err) logger.Error("failed to embed file", "chat", fpath, "error", err)
errCh <- err
// pages.RemovePage(RAGPage) // pages.RemovePage(RAGPage)
return return
} }
pages.RemovePage(RAGPage) }()
colorText() // make new page and write status updates to it
updateStatusLine() // colorText()
// updateStatusLine()
return return
case "delete": case "delete":
fpath = path.Join(cfg.RAGDir, fpath) fpath = path.Join(cfg.RAGDir, fpath)
@@ -148,7 +193,7 @@ func makeRAGTable(fileList []string) *tview.Table {
return return
} }
}) })
return fileTable return ragflex
} }
func makeLoadedRAGTable(fileList []string) *tview.Table { func makeLoadedRAGTable(fileList []string) *tview.Table {

8
tui.go
View File

@@ -3,6 +3,7 @@ package main
import ( import (
"elefant/models" "elefant/models"
"elefant/pngmeta" "elefant/pngmeta"
"elefant/rag"
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
@@ -26,6 +27,8 @@ var (
sysModal *tview.Modal sysModal *tview.Modal
indexPickWindow *tview.InputField indexPickWindow *tview.InputField
renameWindow *tview.InputField renameWindow *tview.InputField
//
longJobStatusCh = make(chan string, 1)
// pages // pages
historyPage = "historyPage" historyPage = "historyPage"
agentPage = "agentPage" agentPage = "agentPage"
@@ -34,6 +37,7 @@ var (
helpPage = "helpPage" helpPage = "helpPage"
renamePage = "renamePage" renamePage = "renamePage"
RAGPage = "RAGPage " RAGPage = "RAGPage "
longStatusPage = "longStatusPage"
// help text // help text
helpText = ` helpText = `
[yellow]Esc[white]: send msg [yellow]Esc[white]: send msg
@@ -155,6 +159,7 @@ func init() {
position = tview.NewTextView(). position = tview.NewTextView().
SetDynamicColors(true). SetDynamicColors(true).
SetTextAlign(tview.AlignCenter) SetTextAlign(tview.AlignCenter)
flex = tview.NewFlex().SetDirection(tview.FlexRow). flex = tview.NewFlex().SetDirection(tview.FlexRow).
AddItem(textView, 0, 40, false). AddItem(textView, 0, 40, false).
AddItem(textArea, 0, 10, true). AddItem(textArea, 0, 10, true).
@@ -466,6 +471,7 @@ func init() {
} }
fileList = append(fileList, f.Name()) fileList = append(fileList, f.Name())
} }
rag.LongJobStatusCh <- "first msg"
chatRAGTable := makeRAGTable(fileList) chatRAGTable := makeRAGTable(fileList)
pages.AddPage(RAGPage, chatRAGTable, true, true) pages.AddPage(RAGPage, chatRAGTable, true, true)
return nil return nil
@@ -482,7 +488,7 @@ func init() {
if strings.HasSuffix(prevText, nl) { if strings.HasSuffix(prevText, nl) {
nl = "" nl = ""
} }
if msgText != "" { if msgText != "" { // continue
fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n", fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
nl, len(chatBody.Messages), cfg.UserRole, msgText) nl, len(chatBody.Messages), cfg.UserRole, msgText)
textArea.SetText("", true) textArea.SetText("", true)