Feat (RAG): tying tui calls to rag funcs [WIP; skip-ci]

RAG itself is annoying to implement properly; plucking sentences with no
context is useless. Also, it should not be a part of the main package; the
same goes for the tui. The number of global vars is absurd.
This commit is contained in:
Grail Finder
2025-01-04 18:13:13 +03:00
parent 461d19aa25
commit 4736e43631
13 changed files with 404 additions and 18 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ history/
*.db
config.toml
sysprompts/*
history_bak/

View File

@@ -32,6 +32,7 @@
- it is a bit clumsy to mix chats in db and chars from the external files, maybe load external files in db on startup?
- let's say we have two (or more) agents with the same name across multiple chats. These agents go and ask the db for topics they memorised. Now they can access topics that aren't meant for them (so memory should have an option: shareable, which indicates whether that memory can be shared across chats);
- delete chat option;
- server mode: no tui but api calls with the func calling, rag, other middleware;
### FIX:
- bot responding (or hanging) blocks everything; +

46
bot.go
View File

@@ -15,6 +15,7 @@ import (
"strings"
"time"
"github.com/neurosnap/sentences/english"
"github.com/rivo/tview"
)
@@ -40,6 +41,16 @@ func formMsg(chatBody *models.ChatBody, newMsg, role string) io.Reader {
if newMsg != "" { // otherwise let the bot continue
newMsg := models.RoleMsg{Role: role, Content: newMsg}
chatBody.Messages = append(chatBody.Messages, newMsg)
// if rag
if cfg.RAGEnabled {
ragResp, err := chatRagUse(newMsg.Content)
if err != nil {
logger.Error("failed to form a rag msg", "error", err)
return nil
}
ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
chatBody.Messages = append(chatBody.Messages, ragMsg)
}
}
data, err := json.Marshal(chatBody)
if err != nil {
@@ -107,6 +118,40 @@ func sendMsgToLLM(body io.Reader) {
}
}
// chatRagUse builds the retrieval context for a user message.
//
// qText is split into sentences; every sentence is embedded with lineToVector
// and looked up in the vector store with searchEmb. The raw texts of the rows
// that were found are returned joined by newlines.
//
// Sentences whose embedding or lookup fails are logged and skipped, so the
// returned string may be empty. An error is returned only when the sentence
// tokenizer cannot be created.
func chatRagUse(qText string) (string, error) {
	tokenizer, err := english.NewSentenceTokenizer(nil)
	if err != nil {
		return "", err
	}
	// TODO: this is where the llm should find the questions in the text and ask them
	sentences := tokenizer.Tokenize(qText)
	resps := make([]string, 0, len(sentences))
	for i, s := range sentences {
		q := s.Text
		emb, err := lineToVector(q)
		if err != nil {
			logger.Error("failed to get embs", "error", err, "index", i, "question", q)
			continue
		}
		if emb == nil {
			// guard: searchEmb dereferences emb, a nil response would panic
			logger.Error("got nil embs", "index", i, "question", q)
			continue
		}
		vec, err := searchEmb(emb)
		if err != nil {
			logger.Error("failed to search embs", "error", err, "index", i, "question", q)
			continue
		}
		if vec == nil {
			// nothing matched; there is no raw text to contribute
			continue
		}
		resps = append(resps, vec.RawText)
	}
	return strings.Join(resps, "\n"), nil
}
func chatRound(userMsg, role string, tv *tview.TextView, regen bool) {
botRespMode = true
reader := formMsg(chatBody, userMsg, role)
@@ -294,4 +339,5 @@ func init() {
Stream: true,
Messages: lastChat,
}
// tempLoad()
}

View File

@@ -1,4 +1,5 @@
APIURL = "http://localhost:8080/v1/chat/completions"
EmbedURL = "http://localhost:8080/v1/embeddings"
ShowSys = true
LogFile = "log.txt"
UserRole = "user"

View File

@@ -8,7 +8,6 @@ import (
type Config struct {
APIURL string `toml:"APIURL"`
EmbedURL string `toml:"EmbedURL"`
ShowSys bool `toml:"ShowSys"`
LogFile string `toml:"LogFile"`
UserRole string `toml:"UserRole"`
@@ -19,6 +18,11 @@ type Config struct {
ToolIcon string `toml:"ToolIcon"`
SysDir string `toml:"SysDir"`
ChunkLimit uint32 `toml:"ChunkLimit"`
// embeddings
RAGEnabled bool `toml:"RAGEnabled"`
EmbedURL string `toml:"EmbedURL"`
HFToken string `toml:"HFToken"`
RAGDir string `toml:"RAGDir"`
}
func LoadConfigOrDefault(fn string) *Config {
@@ -30,6 +34,7 @@ func LoadConfigOrDefault(fn string) *Config {
if err != nil {
fmt.Println("failed to read config from file, loading default")
config.APIURL = "http://localhost:8080/v1/chat/completions"
config.RAGEnabled = false
config.EmbedURL = "http://localhost:8080/v1/embiddings"
config.ShowSys = true
config.LogFile = "log.txt"

1
go.mod
View File

@@ -9,6 +9,7 @@ require (
github.com/glebarez/go-sqlite v1.22.0
github.com/jmoiron/sqlx v1.4.0
github.com/ncruces/go-sqlite3 v0.21.3
github.com/neurosnap/sentences v1.1.2
github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592
)

2
go.sum
View File

@@ -34,6 +34,8 @@ github.com/ncruces/go-sqlite3 v0.21.3 h1:hHkfNQLcbnxPJZhC/RGw9SwP3bfkv/Y0xUHWsr1
github.com/ncruces/go-sqlite3 v0.21.3/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/neurosnap/sentences v1.1.2 h1:iphYOzx/XckXeBiLIUBkPu2EKMJ+6jDbz/sLJZ7ZoUw=
github.com/neurosnap/sentences v1.1.2/go.mod h1:/pwU4E9XNL21ygMIkOIllv/SMy2ujHwpf8GQPu1YPbQ=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592 h1:YIJ+B1hePP6AgynC5TcqpO0H9k3SSoZa2BGyL6vDUzM=

View File

@@ -10,7 +10,7 @@ var (
botRespMode = false
editMode = false
selectedIndex = int(-1)
indexLine = "F12 to show keys help; bot resp mode: %v; char: %s; chat: %s"
indexLine = "F12 to show keys help; bot resp mode: %v; char: %s; chat: %s; RAGEnabled: %v"
focusSwitcher = map[tview.Primitive]tview.Primitive{}
)

View File

@@ -103,6 +103,9 @@ func ReadDirCards(dirname, uname string) ([]*models.CharCard, error) {
}
resp := []*models.CharCard{}
for _, f := range files {
if f.IsDir() {
continue
}
if strings.HasSuffix(f.Name(), ".png") {
fpath := path.Join(dirname, f.Name())
cc, err := ReadCard(fpath, uname)

184
rag.go
View File

@@ -2,27 +2,209 @@ package main
import (
"bytes"
"context"
"elefant/models"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"github.com/neurosnap/sentences/english"
)
// loadRAG reads the text file at fpath, splits it into sentences, embeds the
// sentences in batches through a pool of batchToVectorHFAsync workers and
// stores the resulting vectors with writeVectors.
//
// It returns an error when the file cannot be read, the tokenizer cannot be
// created, the file yields no sentences, or a vector cannot be written.
// Embedding failures are reported on errCh, where the workers log them and
// stop the run; they are not returned from here.
func loadRAG(fpath string) error {
	data, err := os.ReadFile(fpath)
	if err != nil {
		return err
	}
	tokenizer, err := english.NewSentenceTokenizer(nil)
	if err != nil {
		return err
	}
	sentences := tokenizer.Tokenize(string(data))
	if len(sentences) == 0 {
		return fmt.Errorf("no sentences found in %q", fpath)
	}
	sents := make([]string, len(sentences))
	for i, s := range sentences {
		sents[i] = s.Text
	}
	const (
		// TODO: to config
		workers = 5
		// batchToVectorHFAsync detects the last batch with the same value,
		// so the batch size must not be shrunk here.
		batchSize = 200
	)
	nBatches := (len(sents) + batchSize - 1) / batchSize
	var (
		batchCh  = make(chan map[int][]string)
		vectorCh = make(chan []models.VectorRow)
		// One slot per batch: a worker reporting an error never blocks, even
		// when every other worker is busy with its own request.
		errCh      = make(chan error, nBatches)
		workerDone = make(chan struct{}, workers)
	)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// Workers have to run before the first send: batchCh is unbuffered.
	for w := 0; w < workers; w++ {
		go func() {
			defer func() { workerDone <- struct{}{} }()
			batchToVectorHFAsync(ctx, cancel, len(sents), batchCh, vectorCh, errCh)
		}()
	}
	// Close vectorCh once every worker has returned so writeVectors can finish.
	go func() {
		for w := 0; w < workers; w++ {
			<-workerDone
		}
		close(vectorCh)
	}()
	// fill input channel; the map key is the index of the first sentence in the batch
	go func() {
		for left := 0; left < len(sents); left += batchSize {
			right := left + batchSize
			if right > len(sents) {
				right = len(sents)
			}
			select {
			case batchCh <- map[int][]string{left: sents[left:right]}:
			case <-ctx.Done():
				return
			}
		}
	}()
	// write to db
	if err := writeVectors(vectorCh); err != nil {
		// Stop the pipeline and drain what is in flight so that no worker
		// stays blocked on a send to vectorCh.
		cancel()
		go func() {
			for range vectorCh {
			}
		}()
		return err
	}
	return nil
}
// writeVectors drains vectorCh and hands every row of every received batch to
// store.WriteVector. The first write error aborts the loop and is returned;
// nil is returned once vectorCh has been closed and emptied.
func writeVectors(vectorCh <-chan []models.VectorRow) error {
	for {
		rows, open := <-vectorCh
		if !open {
			return nil
		}
		for _, row := range rows {
			if err := store.WriteVector(&row); err != nil {
				return err
			}
		}
	}
}
// batchToVectorHFAsync is a worker loop that embeds batches of lines.
//
// Every map received from inputCh is keyed by the index of the first line of
// the batch; the batch is passed to FecthEmbHF, which publishes the vectors on
// vectorCh and failures on errCh. limit is the total number of lines: the
// worker that handles the batch reaching that limit calls cancel and returns.
//
// The worker also returns when ctx is done, when inputCh is closed, or when it
// receives an error from errCh (in which case it calls cancel as well).
func batchToVectorHFAsync(ctx context.Context, cancel context.CancelFunc, limit int,
	inputCh <-chan map[int][]string, vectorCh chan<- []models.VectorRow, errCh chan error) {
	// TODO: to config; has to match the batch size used by the producer
	const batchSize = 200
	for {
		select {
		case linesMap, ok := <-inputCh:
			if !ok {
				// closed input: without this check the loop would spin on nil maps
				return
			}
			for leftI, v := range linesMap {
				FecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("test_%d", leftI))
				if leftI+batchSize >= limit { // last batch
					cancel()
					return
				}
			}
		case <-ctx.Done():
			logger.Error("got ctx done")
			return
		case err := <-errCh:
			logger.Error("got an error", "error", err)
			cancel()
			return
		}
	}
}
// FecthEmbHF requests embeddings for lines from cfg.EmbedURL, authorising with
// cfg.HFToken, and publishes the result on vectorCh as one models.VectorRow per
// line (embedding, raw text and the given slug).
//
// Every failure is logged and sent on errCh, after which the function returns
// without sending anything on vectorCh. Failures are: the payload or request
// cannot be built, the request fails, the status is not 200, the body cannot
// be decoded, or the number of embeddings does not match the number of lines.
func FecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug string) {
	payload, err := json.Marshal(
		map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}},
	)
	if err != nil {
		logger.Error("failed to marshal payload", "err:", err.Error())
		errCh <- err
		return
	}
	req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload))
	if err != nil {
		// req is nil here; touching its headers would panic
		logger.Error("failed to create a request", "err:", err.Error())
		errCh <- err
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken))
	resp, err := httpClient.Do(req)
	if err != nil {
		logger.Error("failed to embedd line", "err:", err.Error())
		errCh <- err
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		logger.Error("non 200 resp", "code", resp.StatusCode)
		// err is nil at this point, so build a real error for the receiver
		errCh <- fmt.Errorf("non 200 resp; code: %d", resp.StatusCode)
		return
	}
	emb := [][]float32{}
	if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
		logger.Error("failed to embedd line", "err:", err.Error())
		errCh <- err
		return
	}
	if len(emb) == 0 {
		logger.Error("empty emb")
		errCh <- errors.New("empty emb")
		return
	}
	if len(emb) != len(lines) {
		// embeddings are paired with lines by index; a mismatch cannot be aligned
		err = fmt.Errorf("got %d embeddings for %d lines", len(emb), len(lines))
		logger.Error("unexpected number of embeddings", "err:", err.Error())
		errCh <- err
		return
	}
	vectors := make([]models.VectorRow, len(emb))
	for i, e := range emb {
		vectors[i] = models.VectorRow{
			Embeddings: e,
			RawText:    lines[i],
			Slug:       slug,
		}
	}
	vectorCh <- vectors
}
// batchToVectorHF requests embeddings for lines from cfg.EmbedURL, authorising
// with cfg.HFToken, and returns the decoded embeddings.
//
// It returns a non-nil error when the payload or request cannot be built, the
// request fails, the status is not 200, the body cannot be decoded, or the
// response holds no embeddings.
func batchToVectorHF(lines []string) ([][]float32, error) {
	payload, err := json.Marshal(
		map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}},
	)
	if err != nil {
		logger.Error("failed to marshal payload", "err:", err.Error())
		return nil, err
	}
	req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload))
	if err != nil {
		// req is nil here; touching its headers would panic
		logger.Error("failed to create a request", "err:", err.Error())
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken))
	resp, err := httpClient.Do(req)
	if err != nil {
		logger.Error("failed to embedd line", "err:", err.Error())
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		logger.Error("non 200 resp", "code", resp.StatusCode)
		// err is nil at this point; returning it would signal success
		return nil, fmt.Errorf("non 200 resp; code: %d", resp.StatusCode)
	}
	emb := [][]float32{}
	if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
		logger.Error("failed to embedd line", "err:", err.Error())
		return nil, err
	}
	if len(emb) == 0 {
		logger.Error("empty emb")
		return nil, errors.New("empty emb")
	}
	return emb, nil
}
// lineToVector posts line as {"content": line} to cfg.EmbedURL and returns the
// decoded embedding response.
//
// It returns a non-nil error when the payload cannot be marshalled, the
// request fails, the status is not 200, the body cannot be decoded, or the
// decoded embedding is empty. On success the returned pointer is never nil.
func lineToVector(line string) (*models.EmbeddingResp, error) {
	payload, err := json.Marshal(map[string]string{"content": line})
	if err != nil {
		logger.Error("failed to marshal payload", "err:", err.Error())
		return nil, err
	}
	// nolint
	resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload))
	if err != nil {
		logger.Error("failed to embedd line", "err:", err.Error())
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		logger.Error("non 200 resp", "code", resp.StatusCode)
		// err is nil at this point; returning it would hand the caller (nil, nil)
		return nil, fmt.Errorf("non 200 resp; code: %d", resp.StatusCode)
	}
	emb := models.EmbeddingResp{}
	if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
		logger.Error("failed to embedd line", "err:", err.Error())
		return nil, err
	}
	if len(emb.Embedding) == 0 {
		logger.Error("empty emb")
		return nil, errors.New("empty emb")
	}
	return &emb, nil
}
@@ -36,5 +218,5 @@ func saveLine(topic, line string, emb *models.EmbeddingResp) error {
}
// searchEmb passes the embedding held by emb to store.SearchClosest and
// returns the row and error it reports. emb must not be nil: emb.Embedding is
// read without a check.
func searchEmb(emb *models.EmbeddingResp) (*models.VectorRow, error) {
return store.SearchClosest([5120]float32(emb.Embedding))
return store.SearchClosest(emb.Embedding)
}

View File

@@ -4,3 +4,10 @@ CREATE VIRTUAL TABLE IF NOT EXISTS embeddings USING vec0(
slug TEXT NOT NULL,
raw_text TEXT NOT NULL
);
CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_384 USING vec0(
id INTEGER PRIMARY KEY AUTOINCREMENT,
embedding FLOAT[384],
slug TEXT NOT NULL,
raw_text TEXT NOT NULL
);

View File

@@ -2,6 +2,7 @@ package storage
import (
"elefant/models"
"errors"
"fmt"
"log"
"unsafe"
@@ -11,29 +12,61 @@ import (
type VectorRepo interface {
WriteVector(*models.VectorRow) error
SearchClosest(q [5120]float32) (*models.VectorRow, error)
SearchClosest(q []float32) (*models.VectorRow, error)
}
var vecTableName = "embeddings"
var (
vecTableName = "embeddings"
vecTableName384 = "embeddings_384"
)
// fetchTableName picks the vector table that matches the dimension of emb:
// vecTableName for 5120 values and vecTableName384 for 384 values. Any other
// length yields an empty name and an error that reports the length.
func fetchTableName(emb []float32) (string, error) {
	dim := len(emb)
	if dim == 5120 {
		return vecTableName, nil
	}
	if dim == 384 {
		return vecTableName384, nil
	}
	return "", fmt.Errorf("no table for the size of %d", dim)
}
func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
tableName, err := fetchTableName(row.Embeddings)
if err != nil {
return err
}
stmt, _, err := p.s3Conn.Prepare(
fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text) VALUES (?, ?, ?)", vecTableName))
defer stmt.Close()
fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text) VALUES (?, ?, ?)", tableName))
if err != nil {
p.logger.Error("failed to prep a stmt", "error", err)
return err
}
defer stmt.Close()
v, err := sqlite_vec.SerializeFloat32(row.Embeddings)
if err != nil {
p.logger.Error("failed to serialize vector",
"emb-len", len(row.Embeddings), "error", err)
return err
}
stmt.BindInt(1, int(row.ID))
stmt.BindBlob(2, v)
stmt.BindText(3, row.Slug)
stmt.BindText(4, row.RawText)
if v == nil {
err = errors.New("empty vector after serialization")
p.logger.Error("empty vector after serialization",
"emb-len", len(row.Embeddings), "text", row.RawText, "error", err)
return err
}
if err := stmt.BindBlob(1, v); err != nil {
p.logger.Error("failed to bind", "error", err)
return err
}
if err := stmt.BindText(2, row.Slug); err != nil {
p.logger.Error("failed to bind", "error", err)
return err
}
if err := stmt.BindText(3, row.RawText); err != nil {
p.logger.Error("failed to bind", "error", err)
return err
}
err = stmt.Exec()
if err != nil {
p.logger.Error("failed exec a stmt", "error", err)
@@ -46,19 +79,19 @@ func decodeUnsafe(bs []byte) []float32 {
return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4)
}
func (p ProviderSQL) SearchClosest(q [5120]float32) (*models.VectorRow, error) {
stmt, _, err := p.s3Conn.Prepare(`
SELECT
func (p ProviderSQL) SearchClosest(q []float32) (*models.VectorRow, error) {
stmt, _, err := p.s3Conn.Prepare(
fmt.Sprintf(`SELECT
id,
distance,
embedding,
slug,
raw_text
FROM vec_items
FROM %s
WHERE embedding MATCH ?
ORDER BY distance
LIMIT 4
`)
`, vecTableName))
if err != nil {
log.Fatal(err)
}
@@ -66,7 +99,10 @@ func (p ProviderSQL) SearchClosest(q [5120]float32) (*models.VectorRow, error) {
if err != nil {
log.Fatal(err)
}
stmt.BindBlob(1, query)
if err := stmt.BindBlob(1, query); err != nil {
p.logger.Error("failed to bind", "error", err)
return nil, err
}
resp := make([]models.VectorRow, 4)
i := 0
for stmt.Step() {

103
tui.go
View File

@@ -4,6 +4,7 @@ import (
"elefant/models"
"elefant/pngmeta"
"fmt"
"os"
"strconv"
"strings"
"time"
@@ -32,6 +33,7 @@ var (
indexPage = "indexPage"
helpPage = "helpPage"
renamePage = "renamePage"
RAGPage = "RAGPage "
// help text
helpText = `
[yellow]Esc[white]: send msg
@@ -130,6 +132,79 @@ func makeChatTable(chatList []string) *tview.Table {
return chatActTable
}
// makeRAGTable builds the table shown on RAGPage: one row per entry of
// fileList, with the file name in the first column followed by one cell per
// action ("load", "rename", "delete").
//
// Esc or F1 removes the page; Enter makes the cells selectable. Selecting a
// cell runs the action named by the cell text for the file of that row;
// selecting any other cell (e.g. the file name) just removes the page.
// fileList is expected to hold file names relative to cfg.RAGDir.
func makeRAGTable(fileList []string) *tview.Table {
	actions := []string{"load", "rename", "delete"}
	ragTable := tview.NewTable().
		SetBorders(true)
	for r, fname := range fileList {
		ragTable.SetCell(r, 0,
			tview.NewTableCell(fname).
				SetTextColor(tcell.ColorWhite).
				SetAlign(tview.AlignCenter))
		for c, action := range actions {
			ragTable.SetCell(r, c+1,
				tview.NewTableCell(action).
					SetTextColor(tcell.ColorWhite).
					SetAlign(tview.AlignCenter))
		}
	}
	ragTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) {
		if key == tcell.KeyEsc || key == tcell.KeyF1 {
			pages.RemovePage(RAGPage)
			return
		}
		if key == tcell.KeyEnter {
			ragTable.SetSelectable(true, true)
		}
	}).SetSelectedFunc(func(row int, column int) {
		if row < 0 || row >= len(fileList) {
			// empty table or stale selection: there is no file to act on
			pages.RemovePage(RAGPage)
			return
		}
		tc := ragTable.GetCell(row, column)
		tc.SetTextColor(tcell.ColorRed)
		ragTable.SetSelectable(false, false)
		fname := fileList[row]
		switch tc.Text {
		case "load":
			// fileList holds bare names, so the file has to be looked up inside
			// the RAG directory rather than in the working directory.
			fpath := fname
			if cfg.RAGDir != "" {
				fpath = cfg.RAGDir + string(os.PathSeparator) + fname
			}
			if err := loadRAG(fpath); err != nil {
				logger.Error("failed to load rag file", "filepath", fpath, "error", err)
				pages.RemovePage(RAGPage)
				return
			}
			pages.RemovePage(RAGPage)
			colorText()
			updateStatusLine()
			return
		case "rename":
			// NOTE(review): this opens the chat rename window, not a file rename — confirm intent
			pages.RemovePage(RAGPage)
			pages.AddPage(renamePage, renameWindow, true, true)
			return
		case "delete":
			// NOTE(review): this removes the chat whose name equals the file name,
			// the file itself is left untouched — confirm intent
			sc, ok := chatMap[fname]
			if !ok {
				// no chat found
				pages.RemovePage(RAGPage)
				return
			}
			if err := store.RemoveChat(sc.ID); err != nil {
				logger.Error("failed to remove chat from db", "chat_id", sc.ID, "chat_name", sc.Name)
			}
			if err := notifyUser("chat deleted", fname+" was deleted"); err != nil {
				logger.Error("failed to send notification", "error", err)
			}
			pages.RemovePage(RAGPage)
			return
		default:
			pages.RemovePage(RAGPage)
			return
		}
	})
	return ragTable
}
// // code block colors get interrupted by " & *
// func codeBlockColor(text string) string {
// fi := strings.Index(text, "```")
@@ -153,7 +228,7 @@ func colorText() {
}
func updateStatusLine() {
position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName))
position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled))
}
func initSysCards() ([]string, error) {
@@ -379,6 +454,7 @@ func init() {
textView.SetText(chatToText(cfg.ShowSys))
colorText()
textView.ScrollToEnd()
// init sysmap
_, err := initSysCards()
if err != nil {
logger.Error("failed to init sys cards", "error", err)
@@ -456,6 +532,12 @@ func init() {
pages.AddPage(indexPage, indexPickWindow, true, true)
return nil
}
if event.Key() == tcell.KeyF11 {
// xor
cfg.RAGEnabled = cfg.RAGEnabled != true
updateStatusLine()
return nil
}
if event.Key() == tcell.KeyF12 {
// help window cheatsheet
pages.AddPage(helpPage, helpView, true, true)
@@ -496,6 +578,25 @@ func init() {
updateStatusLine()
return nil
}
if event.Key() == tcell.KeyCtrlR && cfg.HFToken != "" {
// rag load
// menu of the text files from defined rag directory
files, err := os.ReadDir(cfg.RAGDir)
if err != nil {
logger.Error("failed to read dir", "dir", cfg.RAGDir, "error", err)
return nil
}
fileList := []string{}
for _, f := range files {
if f.IsDir() {
continue
}
fileList = append(fileList, f.Name())
}
chatRAGTable := makeRAGTable(fileList)
pages.AddPage(RAGPage, chatRAGTable, true, true)
return nil
}
// cannot send msg in editMode or botRespMode
if event.Key() == tcell.KeyEscape && !editMode && !botRespMode {
position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName))