Feat: RAG file loading status textview
This commit is contained in:
38
rag/main.go
38
rag/main.go
@@ -18,6 +18,14 @@ import (
|
|||||||
"github.com/neurosnap/sentences/english"
|
"github.com/neurosnap/sentences/english"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Shared state used to report the progress of long-running RAG jobs to the
// UI layer.
var (
	// LongJobStatusCh carries human-readable progress messages from the RAG
	// loading pipeline to whoever displays them. It is buffered with a
	// capacity of one, so a send blocks whenever a message is already
	// pending and nothing is receiving from the channel.
	LongJobStatusCh = make(chan string, 1)

	// messages

	// FinishedRAGStatus is sent by writeVectors once the vector channel has
	// been drained; receivers compare against it to detect the end of a job.
	FinishedRAGStatus = "finished loading RAG file; press Enter"
	// LoadedFileRAGStatus is sent by LoadRAG right after it logs that the
	// file was loaded.
	LoadedFileRAGStatus = "loaded file"
	// ErrRAGStatus is sent by writeVectors when writing a vector to the
	// store fails.
	ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
)
|
||||||
|
|
||||||
type RAG struct {
|
type RAG struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
store storage.FullRepo
|
store storage.FullRepo
|
||||||
@@ -42,6 +50,7 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r.logger.Info("rag: loaded file", "fp", fpath)
|
r.logger.Info("rag: loaded file", "fp", fpath)
|
||||||
|
LongJobStatusCh <- LoadedFileRAGStatus
|
||||||
fileText := string(data)
|
fileText := string(data)
|
||||||
tokenizer, err := english.NewSentenceTokenizer(nil)
|
tokenizer, err := english.NewSentenceTokenizer(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -49,7 +58,6 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
}
|
}
|
||||||
sentences := tokenizer.Tokenize(fileText)
|
sentences := tokenizer.Tokenize(fileText)
|
||||||
sents := make([]string, len(sentences))
|
sents := make([]string, len(sentences))
|
||||||
r.logger.Info("rag: sentences", "#", len(sents))
|
|
||||||
for i, s := range sentences {
|
for i, s := range sentences {
|
||||||
sents[i] = s.Text
|
sents[i] = s.Text
|
||||||
}
|
}
|
||||||
@@ -60,9 +68,7 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
batchSize = 100
|
batchSize = 100
|
||||||
maxChSize = 1000
|
maxChSize = 1000
|
||||||
//
|
//
|
||||||
// psize = 3
|
|
||||||
wordLimit = 80
|
wordLimit = 80
|
||||||
//
|
|
||||||
left = 0
|
left = 0
|
||||||
right = batchSize
|
right = batchSize
|
||||||
batchCh = make(chan map[int][]string, maxChSize)
|
batchCh = make(chan map[int][]string, maxChSize)
|
||||||
@@ -84,13 +90,6 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
par.Reset()
|
par.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// for i := 0; i < len(sents); i += psize {
|
|
||||||
// if len(sents) < i+psize {
|
|
||||||
// paragraphs = append(paragraphs, strings.Join(sents[i:], " "))
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
// paragraphs = append(paragraphs, strings.Join(sents[i:i+psize], " "))
|
|
||||||
// }
|
|
||||||
if len(paragraphs) < batchSize {
|
if len(paragraphs) < batchSize {
|
||||||
batchSize = len(paragraphs)
|
batchSize = len(paragraphs)
|
||||||
}
|
}
|
||||||
@@ -105,7 +104,9 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
left, right = right, right+batchSize
|
left, right = right, right+batchSize
|
||||||
ctn++
|
ctn++
|
||||||
}
|
}
|
||||||
r.logger.Info("finished batching", "batches#", len(batchCh), "paragraphs", len(paragraphs), "sentences", len(sents))
|
finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", len(batchCh), len(paragraphs), len(sents))
|
||||||
|
r.logger.Info(finishedBatchesMsg)
|
||||||
|
LongJobStatusCh <- finishedBatchesMsg
|
||||||
for w := 0; w < workers; w++ {
|
for w := 0; w < workers; w++ {
|
||||||
go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
|
go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
|
||||||
}
|
}
|
||||||
@@ -121,6 +122,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
|||||||
for _, vector := range batch {
|
for _, vector := range batch {
|
||||||
if err := r.store.WriteVector(&vector); err != nil {
|
if err := r.store.WriteVector(&vector); err != nil {
|
||||||
r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug)
|
r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug)
|
||||||
|
LongJobStatusCh <- ErrRAGStatus
|
||||||
continue // a duplicate is not critical
|
continue // a duplicate is not critical
|
||||||
// return err
|
// return err
|
||||||
}
|
}
|
||||||
@@ -128,6 +130,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
|||||||
r.logger.Info("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
|
r.logger.Info("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
|
||||||
if len(vectorCh) == 0 {
|
if len(vectorCh) == 0 {
|
||||||
r.logger.Info("finished writing vectors")
|
r.logger.Info("finished writing vectors")
|
||||||
|
LongJobStatusCh <- FinishedRAGStatus
|
||||||
defer close(vectorCh)
|
defer close(vectorCh)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -150,10 +153,6 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[
|
|||||||
case linesMap := <-inputCh:
|
case linesMap := <-inputCh:
|
||||||
for leftI, v := range linesMap {
|
for leftI, v := range linesMap {
|
||||||
r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename)
|
r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename)
|
||||||
// if leftI+200 >= limit { // last batch
|
|
||||||
// // doneCh <- true
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
lock.Unlock()
|
lock.Unlock()
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
@@ -162,6 +161,7 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.logger.Info("to vector batches", "batches#", len(inputCh), "worker#", id)
|
r.logger.Info("to vector batches", "batches#", len(inputCh), "worker#", id)
|
||||||
|
LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,8 +183,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
|
|||||||
}
|
}
|
||||||
req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
|
req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
// nolint
|
|
||||||
// resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to embedd line", "err:", err.Error())
|
r.logger.Error("failed to embedd line", "err:", err.Error())
|
||||||
errCh <- err
|
errCh <- err
|
||||||
@@ -194,9 +192,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
|
|||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
r.logger.Error("non 200 resp", "code", resp.StatusCode)
|
r.logger.Error("non 200 resp", "code", resp.StatusCode)
|
||||||
return
|
return
|
||||||
// err = fmt.Errorf("non 200 resp; url: %s; code %d", r.cfg.EmbedURL, resp.StatusCode)
|
|
||||||
// errCh <- err
|
|
||||||
// return
|
|
||||||
}
|
}
|
||||||
emb := [][]float32{}
|
emb := [][]float32{}
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
|
||||||
@@ -224,7 +219,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||||
// payload, err := json.Marshal(map[string]string{"content": line})
|
|
||||||
lines := []string{line}
|
lines := []string{line}
|
||||||
payload, err := json.Marshal(
|
payload, err := json.Marshal(
|
||||||
map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}},
|
map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}},
|
||||||
@@ -241,7 +235,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
|
|||||||
}
|
}
|
||||||
req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
|
req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
// resp, err := req.Post(r.cfg.EmbedURL, "application/json", bytes.NewReader(payload))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to embedd line", "err:", err.Error())
|
r.logger.Error("failed to embedd line", "err:", err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -252,7 +245,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
|
|||||||
r.logger.Error(err.Error())
|
r.logger.Error(err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// emb := models.EmbeddingResp{}
|
|
||||||
emb := [][]float32{}
|
emb := [][]float32{}
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
|
||||||
r.logger.Error("failed to embedd line", "err:", err.Error())
|
r.logger.Error("failed to embedd line", "err:", err.Error())
|
||||||
|
|||||||
57
tables.go
57
tables.go
@@ -1,8 +1,12 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"elefant/rag"
|
||||||
|
|
||||||
"github.com/gdamore/tcell/v2"
|
"github.com/gdamore/tcell/v2"
|
||||||
"github.com/rivo/tview"
|
"github.com/rivo/tview"
|
||||||
@@ -85,11 +89,21 @@ func makeChatTable(chatList []string) *tview.Table {
|
|||||||
return chatActTable
|
return chatActTable
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeRAGTable(fileList []string) *tview.Table {
|
// func makeRAGTable(fileList []string) *tview.Table {
|
||||||
|
func makeRAGTable(fileList []string) *tview.Flex {
|
||||||
actions := []string{"load", "delete"}
|
actions := []string{"load", "delete"}
|
||||||
rows, cols := len(fileList), len(actions)+1
|
rows, cols := len(fileList), len(actions)+1
|
||||||
fileTable := tview.NewTable().
|
fileTable := tview.NewTable().
|
||||||
SetBorders(true)
|
SetBorders(true)
|
||||||
|
longStatusView := tview.NewTextView()
|
||||||
|
longStatusView.SetText("status text")
|
||||||
|
longStatusView.SetBorder(true).SetTitle("status")
|
||||||
|
longStatusView.SetChangedFunc(func() {
|
||||||
|
app.Draw()
|
||||||
|
})
|
||||||
|
ragflex := tview.NewFlex().SetDirection(tview.FlexRow).
|
||||||
|
AddItem(longStatusView, 0, 10, false).
|
||||||
|
AddItem(fileTable, 0, 60, true)
|
||||||
for r := 0; r < rows; r++ {
|
for r := 0; r < rows; r++ {
|
||||||
for c := 0; c < cols; c++ {
|
for c := 0; c < cols; c++ {
|
||||||
color := tcell.ColorWhite
|
color := tcell.ColorWhite
|
||||||
@@ -106,6 +120,33 @@ func makeRAGTable(fileList []string) *tview.Table {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer pages.RemovePage(RAGPage)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
if err == nil {
|
||||||
|
logger.Error("somehow got a nil err", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Error("got an err in rag status", "error", err, "textview", longStatusView)
|
||||||
|
longStatusView.SetText(fmt.Sprintf("%v", err))
|
||||||
|
close(errCh)
|
||||||
|
return
|
||||||
|
case status := <-rag.LongJobStatusCh:
|
||||||
|
logger.Info("reading status channel", "status", status)
|
||||||
|
longStatusView.SetText(status)
|
||||||
|
// fmt.Fprintln(longStatusView, status)
|
||||||
|
// app.Sync()
|
||||||
|
if status == rag.FinishedRAGStatus {
|
||||||
|
close(errCh)
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
fileTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) {
|
fileTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) {
|
||||||
if key == tcell.KeyEsc || key == tcell.KeyF1 {
|
if key == tcell.KeyEsc || key == tcell.KeyF1 {
|
||||||
pages.RemovePage(RAGPage)
|
pages.RemovePage(RAGPage)
|
||||||
@@ -115,7 +156,7 @@ func makeRAGTable(fileList []string) *tview.Table {
|
|||||||
fileTable.SetSelectable(true, true)
|
fileTable.SetSelectable(true, true)
|
||||||
}
|
}
|
||||||
}).SetSelectedFunc(func(row int, column int) {
|
}).SetSelectedFunc(func(row int, column int) {
|
||||||
defer pages.RemovePage(RAGPage)
|
// defer pages.RemovePage(RAGPage)
|
||||||
tc := fileTable.GetCell(row, column)
|
tc := fileTable.GetCell(row, column)
|
||||||
tc.SetTextColor(tcell.ColorRed)
|
tc.SetTextColor(tcell.ColorRed)
|
||||||
fileTable.SetSelectable(false, false)
|
fileTable.SetSelectable(false, false)
|
||||||
@@ -124,14 +165,18 @@ func makeRAGTable(fileList []string) *tview.Table {
|
|||||||
switch tc.Text {
|
switch tc.Text {
|
||||||
case "load":
|
case "load":
|
||||||
fpath = path.Join(cfg.RAGDir, fpath)
|
fpath = path.Join(cfg.RAGDir, fpath)
|
||||||
|
longStatusView.SetText("clicked load")
|
||||||
|
go func() {
|
||||||
if err := ragger.LoadRAG(fpath); err != nil {
|
if err := ragger.LoadRAG(fpath); err != nil {
|
||||||
logger.Error("failed to embed file", "chat", fpath, "error", err)
|
logger.Error("failed to embed file", "chat", fpath, "error", err)
|
||||||
|
errCh <- err
|
||||||
// pages.RemovePage(RAGPage)
|
// pages.RemovePage(RAGPage)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pages.RemovePage(RAGPage)
|
}()
|
||||||
colorText()
|
// make new page and write status updates to it
|
||||||
updateStatusLine()
|
// colorText()
|
||||||
|
// updateStatusLine()
|
||||||
return
|
return
|
||||||
case "delete":
|
case "delete":
|
||||||
fpath = path.Join(cfg.RAGDir, fpath)
|
fpath = path.Join(cfg.RAGDir, fpath)
|
||||||
@@ -148,7 +193,7 @@ func makeRAGTable(fileList []string) *tview.Table {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return fileTable
|
return ragflex
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeLoadedRAGTable(fileList []string) *tview.Table {
|
func makeLoadedRAGTable(fileList []string) *tview.Table {
|
||||||
|
|||||||
8
tui.go
8
tui.go
@@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"elefant/models"
|
"elefant/models"
|
||||||
"elefant/pngmeta"
|
"elefant/pngmeta"
|
||||||
|
"elefant/rag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -26,6 +27,8 @@ var (
|
|||||||
sysModal *tview.Modal
|
sysModal *tview.Modal
|
||||||
indexPickWindow *tview.InputField
|
indexPickWindow *tview.InputField
|
||||||
renameWindow *tview.InputField
|
renameWindow *tview.InputField
|
||||||
|
//
|
||||||
|
longJobStatusCh = make(chan string, 1)
|
||||||
// pages
|
// pages
|
||||||
historyPage = "historyPage"
|
historyPage = "historyPage"
|
||||||
agentPage = "agentPage"
|
agentPage = "agentPage"
|
||||||
@@ -34,6 +37,7 @@ var (
|
|||||||
helpPage = "helpPage"
|
helpPage = "helpPage"
|
||||||
renamePage = "renamePage"
|
renamePage = "renamePage"
|
||||||
RAGPage = "RAGPage "
|
RAGPage = "RAGPage "
|
||||||
|
longStatusPage = "longStatusPage"
|
||||||
// help text
|
// help text
|
||||||
helpText = `
|
helpText = `
|
||||||
[yellow]Esc[white]: send msg
|
[yellow]Esc[white]: send msg
|
||||||
@@ -155,6 +159,7 @@ func init() {
|
|||||||
position = tview.NewTextView().
|
position = tview.NewTextView().
|
||||||
SetDynamicColors(true).
|
SetDynamicColors(true).
|
||||||
SetTextAlign(tview.AlignCenter)
|
SetTextAlign(tview.AlignCenter)
|
||||||
|
|
||||||
flex = tview.NewFlex().SetDirection(tview.FlexRow).
|
flex = tview.NewFlex().SetDirection(tview.FlexRow).
|
||||||
AddItem(textView, 0, 40, false).
|
AddItem(textView, 0, 40, false).
|
||||||
AddItem(textArea, 0, 10, true).
|
AddItem(textArea, 0, 10, true).
|
||||||
@@ -466,6 +471,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
fileList = append(fileList, f.Name())
|
fileList = append(fileList, f.Name())
|
||||||
}
|
}
|
||||||
|
rag.LongJobStatusCh <- "first msg"
|
||||||
chatRAGTable := makeRAGTable(fileList)
|
chatRAGTable := makeRAGTable(fileList)
|
||||||
pages.AddPage(RAGPage, chatRAGTable, true, true)
|
pages.AddPage(RAGPage, chatRAGTable, true, true)
|
||||||
return nil
|
return nil
|
||||||
@@ -482,7 +488,7 @@ func init() {
|
|||||||
if strings.HasSuffix(prevText, nl) {
|
if strings.HasSuffix(prevText, nl) {
|
||||||
nl = ""
|
nl = ""
|
||||||
}
|
}
|
||||||
if msgText != "" {
|
if msgText != "" { // continue
|
||||||
fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
|
fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
|
||||||
nl, len(chatBody.Messages), cfg.UserRole, msgText)
|
nl, len(chatBody.Messages), cfg.UserRole, msgText)
|
||||||
textArea.SetText("", true)
|
textArea.SetText("", true)
|
||||||
|
|||||||
Reference in New Issue
Block a user