Feat: add toml config

This commit is contained in:
Grail Finder
2024-11-27 20:16:58 +03:00
parent 55007d27f8
commit 14d706f94a
7 changed files with 95 additions and 51 deletions

67
bot.go
View File

@@ -3,6 +3,7 @@ package main
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"elefant/config"
"elefant/models" "elefant/models"
"elefant/storage" "elefant/storage"
"encoding/json" "encoding/json"
@@ -22,27 +23,16 @@ var httpClient = http.Client{
} }
var ( var (
logger *slog.Logger cfg *config.Config
userRole = "user" logger *slog.Logger
assistantRole = "assistant" chunkLimit = 1000
toolRole = "tool" activeChatName string
assistantIcon = "<🤖>: " chunkChan = make(chan string, 10)
userIcon = "<user>: " streamDone = make(chan bool, 1)
// TODO: pass as an cli arg or have config chatBody *models.ChatBody
APIURL = "http://localhost:8080/v1/chat/completions" store storage.FullRepo
logFileName = "log.txt" defaultFirstMsg = "Hello! What can I do for you?"
showSystemMsgs = true defaultStarter = []models.MessagesStory{}
chunkLimit = 1000
activeChatName string
chunkChan = make(chan string, 10)
streamDone = make(chan bool, 1)
chatBody *models.ChatBody
store storage.FullRepo
defaultFirstMsg = "Hello! What can I do for you?"
defaultStarter = []models.MessagesStory{
{Role: "system", Content: systemMsg},
{Role: assistantRole, Content: defaultFirstMsg},
}
defaultStarterBytes = []byte{} defaultStarterBytes = []byte{}
interruptResp = false interruptResp = false
) )
@@ -64,14 +54,14 @@ func formMsg(chatBody *models.ChatBody, newMsg, role string) io.Reader {
// func sendMsgToLLM(body io.Reader) (*models.LLMRespChunk, error) { // func sendMsgToLLM(body io.Reader) (*models.LLMRespChunk, error) {
func sendMsgToLLM(body io.Reader) (any, error) { func sendMsgToLLM(body io.Reader) (any, error) {
resp, err := httpClient.Post(APIURL, "application/json", body) resp, err := httpClient.Post(cfg.APIURL, "application/json", body)
if err != nil { if err != nil {
logger.Error("llamacpp api", "error", err) logger.Error("llamacpp api", "error", err)
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
llmResp := []models.LLMRespChunk{} llmResp := []models.LLMRespChunk{}
// chunkChan <- assistantIcon // chunkChan <- cfg.AssistantIcon
reader := bufio.NewReader(resp.Body) reader := bufio.NewReader(resp.Body)
counter := 0 counter := 0
for { for {
@@ -128,7 +118,7 @@ func chatRound(userMsg, role string, tv *tview.TextView) {
go sendMsgToLLM(reader) go sendMsgToLLM(reader)
if userMsg != "" { // no need to write assistant icon since we continue old message if userMsg != "" { // no need to write assistant icon since we continue old message
fmt.Fprintf(tv, fmt.Sprintf("(%d) ", len(chatBody.Messages))) fmt.Fprintf(tv, fmt.Sprintf("(%d) ", len(chatBody.Messages)))
fmt.Fprintf(tv, assistantIcon) fmt.Fprintf(tv, cfg.AssistantIcon)
} }
respText := strings.Builder{} respText := strings.Builder{}
out: out:
@@ -145,7 +135,7 @@ out:
} }
botRespMode = false botRespMode = false
chatBody.Messages = append(chatBody.Messages, models.MessagesStory{ chatBody.Messages = append(chatBody.Messages, models.MessagesStory{
Role: assistantRole, Content: respText.String(), Role: cfg.AssistantRole, Content: respText.String(),
}) })
// bot msg is done; // bot msg is done;
// now check it for func call // now check it for func call
@@ -174,18 +164,18 @@ func findCall(msg string, tv *tview.TextView) {
f, ok := fnMap[fc.Name] f, ok := fnMap[fc.Name]
if !ok { if !ok {
m := fmt.Sprintf("%s is not implemented", fc.Name) m := fmt.Sprintf("%s is not implemented", fc.Name)
chatRound(m, toolRole, tv) chatRound(m, cfg.ToolRole, tv)
return return
} }
resp := f(fc.Args...) resp := f(fc.Args...)
toolMsg := fmt.Sprintf("tool response: %+v", string(resp)) toolMsg := fmt.Sprintf("tool response: %+v", string(resp))
chatRound(toolMsg, toolRole, tv) chatRound(toolMsg, cfg.ToolRole, tv)
} }
func chatToTextSlice(showSys bool) []string { func chatToTextSlice(showSys bool) []string {
resp := make([]string, len(chatBody.Messages)) resp := make([]string, len(chatBody.Messages))
for i, msg := range chatBody.Messages { for i, msg := range chatBody.Messages {
if !showSys && (msg.Role != assistantRole && msg.Role != userRole) { if !showSys && (msg.Role != cfg.AssistantRole && msg.Role != cfg.UserRole) {
continue continue
} }
resp[i] = msg.ToText(i) resp[i] = msg.ToText(i)
@@ -201,14 +191,14 @@ func chatToText(showSys bool) string {
func textToMsg(rawMsg string) models.MessagesStory { func textToMsg(rawMsg string) models.MessagesStory {
msg := models.MessagesStory{} msg := models.MessagesStory{}
// system and tool? // system and tool?
if strings.HasPrefix(rawMsg, assistantIcon) { if strings.HasPrefix(rawMsg, cfg.AssistantIcon) {
msg.Role = assistantRole msg.Role = cfg.AssistantRole
msg.Content = strings.TrimPrefix(rawMsg, assistantIcon) msg.Content = strings.TrimPrefix(rawMsg, cfg.AssistantIcon)
return msg return msg
} }
if strings.HasPrefix(rawMsg, userIcon) { if strings.HasPrefix(rawMsg, cfg.UserIcon) {
msg.Role = userRole msg.Role = cfg.UserRole
msg.Content = strings.TrimPrefix(rawMsg, userIcon) msg.Content = strings.TrimPrefix(rawMsg, cfg.UserIcon)
return msg return msg
} }
return msg return msg
@@ -224,9 +214,14 @@ func textSliceToChat(chat []string) []models.MessagesStory {
} }
func init() { func init() {
file, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) cfg = config.LoadConfigOrDefault("config.example.toml")
defaultStarter = []models.MessagesStory{
{Role: "system", Content: systemMsg},
{Role: cfg.AssistantRole, Content: defaultFirstMsg},
}
file, err := os.OpenFile(cfg.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
logger.Error("failed to open log file", "error", err, "filename", logFileName) logger.Error("failed to open log file", "error", err, "filename", cfg.LogFile)
return return
} }
defaultStarterBytes, err = json.Marshal(defaultStarter) defaultStarterBytes, err = json.Marshal(defaultStarter)

9
config.example.toml Normal file
View File

@@ -0,0 +1,9 @@
APIURL = "http://localhost:8080/v1/chat/completions"
ShowSys = true
LogFile = "log.txt"
UserRole = "user"
ToolRole = "tool"
AssistantRole = "assistant"
AssistantIcon = "<🤖>: "
UserIcon = "<user>: "
ToolIcon = "<>>: "

37
config/config.go Normal file
View File

@@ -0,0 +1,37 @@
package config
import (
"fmt"
"github.com/BurntSushi/toml"
)
type Config struct {
APIURL string `toml:"APIURL"`
ShowSys bool `toml:"ShowSys"`
LogFile string `toml:"LogFile"`
UserRole string `toml:"UserRole"`
ToolRole string `toml:"ToolRole"`
AssistantRole string `toml:"AssistantRole"`
AssistantIcon string `toml:"AssistantIcon"`
UserIcon string `toml:"UserIcon"`
ToolIcon string `toml:"ToolIcon"`
}
func LoadConfigOrDefault(fn string) *Config {
if fn == "" {
fn = "config.toml"
}
config := &Config{}
_, err := toml.DecodeFile(fn, &config)
if err != nil {
fmt.Println("failed to read config from file, loading default")
config.APIURL = "http://localhost:8080/v1/chat/completions"
config.ShowSys = true
config.LogFile = "log.txt"
config.UserRole = "user"
config.ToolRole = "tool"
config.AssistantRole = "assistant"
}
return config
}

1
go.mod
View File

@@ -3,6 +3,7 @@ module elefant
go 1.23.2 go 1.23.2
require ( require (
github.com/BurntSushi/toml v1.4.0
github.com/gdamore/tcell/v2 v2.7.4 github.com/gdamore/tcell/v2 v2.7.4
github.com/glebarez/go-sqlite v1.22.0 github.com/glebarez/go-sqlite v1.22.0
github.com/jmoiron/sqlx v1.4.0 github.com/jmoiron/sqlx v1.4.0

2
go.sum
View File

@@ -1,5 +1,7 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0=
github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko= github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko=

View File

@@ -61,7 +61,7 @@ also:
- others do; - others do;
*/ */
func memorise(args ...string) []byte { func memorise(args ...string) []byte {
agent := assistantRole agent := cfg.AssistantRole
if len(args) < 2 { if len(args) < 2 {
msg := "not enough args to call memorise tool; need topic and data to remember" msg := "not enough args to call memorise tool; need topic and data to remember"
logger.Error(msg) logger.Error(msg)
@@ -79,7 +79,7 @@ func memorise(args ...string) []byte {
} }
func recall(args ...string) []byte { func recall(args ...string) []byte {
agent := assistantRole agent := cfg.AssistantRole
if len(args) < 1 { if len(args) < 1 {
logger.Warn("not enough args to call recall tool") logger.Warn("not enough args to call recall tool")
return nil return nil
@@ -94,7 +94,7 @@ func recall(args ...string) []byte {
} }
func recallTopics(args ...string) []byte { func recallTopics(args ...string) []byte {
agent := assistantRole agent := cfg.AssistantRole
topics, err := store.RecallTopics(agent) topics, err := store.RecallTopics(agent)
if err != nil { if err != nil {
logger.Error("failed to use tool", "error", err, "args", args) logger.Error("failed to use tool", "error", err, "args", args)

24
tui.go
View File

@@ -82,7 +82,7 @@ func init() {
} }
// set chat body // set chat body
chatBody.Messages = defaultStarter chatBody.Messages = defaultStarter
textView.SetText(chatToText(showSystemMsgs)) textView.SetText(chatToText(cfg.ShowSys))
newChat := &models.Chat{ newChat := &models.Chat{
ID: id + 1, ID: id + 1,
Name: fmt.Sprintf("%v_%v", "new", time.Now().Unix()), Name: fmt.Sprintf("%v_%v", "new", time.Now().Unix()),
@@ -111,7 +111,7 @@ func init() {
return return
} }
chatBody.Messages = history chatBody.Messages = history
textView.SetText(chatToText(showSystemMsgs)) textView.SetText(chatToText(cfg.ShowSys))
activeChatName = fn activeChatName = fn
pages.RemovePage("history") pages.RemovePage("history")
return return
@@ -134,7 +134,7 @@ func init() {
} }
chatBody.Messages[0].Content = sysMsg chatBody.Messages[0].Content = sysMsg
// replace textview // replace textview
textView.SetText(chatToText(showSystemMsgs)) textView.SetText(chatToText(cfg.ShowSys))
pages.RemovePage("sys") pages.RemovePage("sys")
} }
}) })
@@ -152,7 +152,7 @@ func init() {
} }
chatBody.Messages[selectedIndex].Content = editedMsg chatBody.Messages[selectedIndex].Content = editedMsg
// change textarea // change textarea
textView.SetText(chatToText(showSystemMsgs)) textView.SetText(chatToText(cfg.ShowSys))
pages.RemovePage("editArea") pages.RemovePage("editArea")
editMode = false editMode = false
return nil return nil
@@ -233,7 +233,7 @@ func init() {
// //
textArea.SetMovedFunc(updateStatusLine) textArea.SetMovedFunc(updateStatusLine)
updateStatusLine() updateStatusLine()
textView.SetText(chatToText(showSystemMsgs)) textView.SetText(chatToText(cfg.ShowSys))
textView.ScrollToEnd() textView.ScrollToEnd()
app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Key() == tcell.KeyF1 { if event.Key() == tcell.KeyF1 {
@@ -251,14 +251,14 @@ func init() {
if event.Key() == tcell.KeyF2 { if event.Key() == tcell.KeyF2 {
// regen last msg // regen last msg
chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1] chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1]
textView.SetText(chatToText(showSystemMsgs)) textView.SetText(chatToText(cfg.ShowSys))
go chatRound("", userRole, textView) go chatRound("", cfg.UserRole, textView)
return nil return nil
} }
if event.Key() == tcell.KeyF3 && !botRespMode { if event.Key() == tcell.KeyF3 && !botRespMode {
// delete last msg // delete last msg
chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1] chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1]
textView.SetText(chatToText(showSystemMsgs)) textView.SetText(chatToText(cfg.ShowSys))
return nil return nil
} }
if event.Key() == tcell.KeyF4 { if event.Key() == tcell.KeyF4 {
@@ -268,9 +268,9 @@ func init() {
return nil return nil
} }
if event.Key() == tcell.KeyF5 { if event.Key() == tcell.KeyF5 {
// switch showSystemMsgs // switch cfg.ShowSys
showSystemMsgs = !showSystemMsgs cfg.ShowSys = !cfg.ShowSys
textView.SetText(chatToText(showSystemMsgs)) textView.SetText(chatToText(cfg.ShowSys))
} }
if event.Key() == tcell.KeyF6 { if event.Key() == tcell.KeyF6 {
interruptResp = true interruptResp = true
@@ -317,7 +317,7 @@ func init() {
textView.ScrollToEnd() textView.ScrollToEnd()
} }
// update statue line // update statue line
go chatRound(msgText, userRole, textView) go chatRound(msgText, cfg.UserRole, textView)
return nil return nil
} }
if event.Key() == tcell.KeyPgUp || event.Key() == tcell.KeyPgDn { if event.Key() == tcell.KeyPgUp || event.Key() == tcell.KeyPgDn {