Enhance: avoid recursion in llm calls
This commit is contained in:
137
bot.go
137
bot.go
@@ -25,7 +25,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/neurosnap/sentences/english"
|
||||
"github.com/rivo/tview"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -33,9 +32,9 @@ var (
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
logLevel = new(slog.LevelVar)
|
||||
)
|
||||
var (
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
activeChatName string
|
||||
chatRoundChan = make(chan *models.ChatRoundReq, 1)
|
||||
chunkChan = make(chan string, 10)
|
||||
openAIToolChan = make(chan string, 10)
|
||||
streamDone = make(chan bool, 1)
|
||||
@@ -699,7 +698,23 @@ func roleToIcon(role string) string {
|
||||
return "<" + role + ">: "
|
||||
}
|
||||
|
||||
func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
|
||||
func chatWatcher(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case chatRoundReq := <-chatRoundChan:
|
||||
if err := chatRound(chatRoundReq); err != nil {
|
||||
logger.Error("failed to chatRound", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chatRound(r *models.ChatRoundReq) error {
|
||||
// chunkChan := make(chan string, 10)
|
||||
// openAIToolChan := make(chan string, 10)
|
||||
// streamDone := make(chan bool, 1)
|
||||
botRespMode = true
|
||||
botPersona := cfg.AssistantRole
|
||||
if cfg.WriteNextMsgAsCompletionAgent != "" {
|
||||
@@ -707,32 +722,23 @@ func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
|
||||
}
|
||||
defer func() { botRespMode = false }()
|
||||
// check that there is a model set to use if is not local
|
||||
if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
|
||||
if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
|
||||
if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
|
||||
logger.Warn("failed ot notify user", "error", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
choseChunkParser()
|
||||
reader, err := chunkParser.FormMsg(userMsg, role, resume)
|
||||
reader, err := chunkParser.FormMsg(r.UserMsg, r.Role, r.Resume)
|
||||
if reader == nil || err != nil {
|
||||
logger.Error("empty reader from msgs", "role", role, "error", err)
|
||||
return
|
||||
logger.Error("empty reader from msgs", "role", r.Role, "error", err)
|
||||
return err
|
||||
}
|
||||
if cfg.SkipLLMResp {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
go sendMsgToLLM(reader)
|
||||
logger.Debug("looking at vars in chatRound", "msg", userMsg, "regen", regen, "resume", resume)
|
||||
if !resume {
|
||||
fmt.Fprintf(tv, "\n[-:-:b](%d) ", len(chatBody.Messages))
|
||||
fmt.Fprint(tv, roleToIcon(botPersona))
|
||||
fmt.Fprint(tv, "[-:-:-]\n")
|
||||
logger.Debug("looking at vars in chatRound", "msg", r.UserMsg, "regen", r.Regen, "resume", r.Resume)
|
||||
if !r.Resume {
|
||||
fmt.Fprintf(textView, "\n[-:-:b](%d) ", len(chatBody.Messages))
|
||||
fmt.Fprint(textView, roleToIcon(botPersona))
|
||||
fmt.Fprint(textView, "[-:-:-]\n")
|
||||
if cfg.ThinkUse && !strings.Contains(cfg.CurrentAPI, "v1") {
|
||||
// fmt.Fprint(tv, "<think>")
|
||||
// fmt.Fprint(textView, "<think>")
|
||||
chunkChan <- "<think>"
|
||||
}
|
||||
}
|
||||
@@ -742,29 +748,29 @@ out:
|
||||
for {
|
||||
select {
|
||||
case chunk := <-chunkChan:
|
||||
fmt.Fprint(tv, chunk)
|
||||
fmt.Fprint(textView, chunk)
|
||||
respText.WriteString(chunk)
|
||||
if scrollToEndEnabled {
|
||||
tv.ScrollToEnd()
|
||||
textView.ScrollToEnd()
|
||||
}
|
||||
// Send chunk to audio stream handler
|
||||
if cfg.TTS_ENABLED {
|
||||
TTSTextChan <- chunk
|
||||
}
|
||||
case toolChunk := <-openAIToolChan:
|
||||
fmt.Fprint(tv, toolChunk)
|
||||
fmt.Fprint(textView, toolChunk)
|
||||
toolResp.WriteString(toolChunk)
|
||||
if scrollToEndEnabled {
|
||||
tv.ScrollToEnd()
|
||||
textView.ScrollToEnd()
|
||||
}
|
||||
case <-streamDone:
|
||||
// drain any remaining chunks from chunkChan before exiting
|
||||
for len(chunkChan) > 0 {
|
||||
chunk := <-chunkChan
|
||||
fmt.Fprint(tv, chunk)
|
||||
fmt.Fprint(textView, chunk)
|
||||
respText.WriteString(chunk)
|
||||
if scrollToEndEnabled {
|
||||
tv.ScrollToEnd()
|
||||
textView.ScrollToEnd()
|
||||
}
|
||||
if cfg.TTS_ENABLED {
|
||||
// Send chunk to audio stream handler
|
||||
@@ -780,7 +786,7 @@ out:
|
||||
}
|
||||
botRespMode = false
|
||||
// numbers in chatbody and displayed must be the same
|
||||
if resume {
|
||||
if r.Resume {
|
||||
chatBody.Messages[len(chatBody.Messages)-1].Content += respText.String()
|
||||
// lastM.Content = lastM.Content + respText.String()
|
||||
// Process the updated message to check for known_to tags in resumed response
|
||||
@@ -797,7 +803,9 @@ out:
|
||||
}
|
||||
logger.Debug("chatRound: before cleanChatBody", "messages_before_clean", len(chatBody.Messages))
|
||||
for i, msg := range chatBody.Messages {
|
||||
logger.Debug("chatRound: before cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
|
||||
logger.Debug("chatRound: before cleaning", "index", i,
|
||||
"role", msg.Role, "content_len", len(msg.Content),
|
||||
"has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
|
||||
}
|
||||
// // Clean null/empty messages to prevent API issues with endpoints like llama.cpp jinja template
|
||||
cleanChatBody()
|
||||
@@ -813,8 +821,9 @@ out:
|
||||
if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil {
|
||||
logger.Warn("failed to update storage", "error", err, "name", activeChatName)
|
||||
}
|
||||
// FIXME: recursive calls
|
||||
findCall(respText.String(), toolResp.String(), tv)
|
||||
if findCall(respText.String(), toolResp.String()) {
|
||||
return nil
|
||||
}
|
||||
// TODO: have a config attr
|
||||
// Check if this message was sent privately to specific characters
|
||||
// If so, trigger those characters to respond if that char is not controlled by user
|
||||
@@ -822,9 +831,10 @@ out:
|
||||
if cfg.AutoTurn {
|
||||
lastMsg := chatBody.Messages[len(chatBody.Messages)-1]
|
||||
if len(lastMsg.KnownTo) > 0 {
|
||||
triggerPrivateMessageResponses(lastMsg, tv)
|
||||
triggerPrivateMessageResponses(lastMsg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanChatBody removes messages with null or empty content to prevent API issues
|
||||
@@ -909,7 +919,8 @@ func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) {
|
||||
return fc, nil
|
||||
}
|
||||
|
||||
func findCall(msg, toolCall string, tv *tview.TextView) {
|
||||
// findCall: adds chatRoundReq into the chatRoundChan and returns true if does
|
||||
func findCall(msg, toolCall string) bool {
|
||||
fc := &models.FuncCall{}
|
||||
if toolCall != "" {
|
||||
// HTML-decode the tool call string to handle encoded characters like < -> <=
|
||||
@@ -927,8 +938,13 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
|
||||
chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
|
||||
// Clear the stored tool call ID after using it (no longer needed)
|
||||
// Trigger the assistant to continue processing with the error message
|
||||
chatRound("", cfg.AssistantRole, tv, false, false)
|
||||
return
|
||||
crr := &models.ChatRoundReq{
|
||||
Role: cfg.AssistantRole,
|
||||
}
|
||||
// provoke next llm msg after failed tool call
|
||||
chatRoundChan <- crr
|
||||
// chatRound("", cfg.AssistantRole, tv, false, false)
|
||||
return true
|
||||
}
|
||||
lastToolCall.Args = openAIToolMap
|
||||
fc = lastToolCall
|
||||
@@ -940,8 +956,8 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
|
||||
}
|
||||
} else {
|
||||
jsStr := toolCallRE.FindString(msg)
|
||||
if jsStr == "" {
|
||||
return
|
||||
if jsStr == "" { // no tool call case
|
||||
return false
|
||||
}
|
||||
prefix := "__tool_call__\n"
|
||||
suffix := "\n__tool_call__"
|
||||
@@ -960,8 +976,13 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
|
||||
chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
|
||||
logger.Debug("findCall: added tool error response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "message_count_after_add", len(chatBody.Messages))
|
||||
// Trigger the assistant to continue processing with the error message
|
||||
chatRound("", cfg.AssistantRole, tv, false, false)
|
||||
return
|
||||
// chatRound("", cfg.AssistantRole, tv, false, false)
|
||||
crr := &models.ChatRoundReq{
|
||||
Role: cfg.AssistantRole,
|
||||
}
|
||||
// provoke next llm msg after failed tool call
|
||||
chatRoundChan <- crr
|
||||
return true
|
||||
}
|
||||
// Update lastToolCall with parsed function call
|
||||
lastToolCall.ID = fc.ID
|
||||
@@ -994,13 +1015,17 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
|
||||
lastToolCall.ID = ""
|
||||
// Trigger the assistant to continue processing with the new tool response
|
||||
// by calling chatRound with empty content to continue the assistant's response
|
||||
chatRound("", cfg.AssistantRole, tv, false, false)
|
||||
return
|
||||
crr := &models.ChatRoundReq{
|
||||
Role: cfg.AssistantRole,
|
||||
}
|
||||
// failed to find tool
|
||||
chatRoundChan <- crr
|
||||
return true
|
||||
}
|
||||
resp := callToolWithAgent(fc.Name, fc.Args)
|
||||
toolMsg := string(resp) // Remove the "tool response: " prefix and %+v formatting
|
||||
logger.Info("llm used tool call", "tool_resp", toolMsg, "tool_attrs", fc)
|
||||
fmt.Fprintf(tv, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
|
||||
fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
|
||||
"\n\n", len(chatBody.Messages), cfg.ToolRole, toolMsg)
|
||||
// Create tool response message with the proper tool_call_id
|
||||
toolResponseMsg := models.RoleMsg{
|
||||
@@ -1014,7 +1039,11 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
|
||||
lastToolCall.ID = ""
|
||||
// Trigger the assistant to continue processing with the new tool response
|
||||
// by calling chatRound with empty content to continue the assistant's response
|
||||
chatRound("", cfg.AssistantRole, tv, false, false)
|
||||
crr := &models.ChatRoundReq{
|
||||
Role: cfg.AssistantRole,
|
||||
}
|
||||
chatRoundChan <- crr
|
||||
return true
|
||||
}
|
||||
|
||||
func chatToTextSlice(messages []models.RoleMsg, showSys bool) []string {
|
||||
@@ -1163,10 +1192,12 @@ func summarizeAndStartNewChat() {
|
||||
}
|
||||
|
||||
func init() {
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
var err error
|
||||
cfg, err = config.LoadConfig("config.toml")
|
||||
if err != nil {
|
||||
fmt.Println("failed to load config.toml")
|
||||
cancel()
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
@@ -1178,6 +1209,8 @@ func init() {
|
||||
os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
slog.Error("failed to open log file", "error", err, "filename", cfg.LogFile)
|
||||
cancel()
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
// load cards
|
||||
@@ -1188,13 +1221,17 @@ func init() {
|
||||
logger = slog.New(slog.NewTextHandler(logfile, &slog.HandlerOptions{Level: logLevel}))
|
||||
store = storage.NewProviderSQL(cfg.DBPATH, logger)
|
||||
if store == nil {
|
||||
cancel()
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
ragger = rag.New(logger, store, cfg)
|
||||
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
|
||||
// load all chats in memory
|
||||
if _, err := loadHistoryChats(); err != nil {
|
||||
logger.Error("failed to load chat", "error", err)
|
||||
cancel()
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
lastToolCall = &models.FuncCall{}
|
||||
@@ -1215,11 +1252,12 @@ func init() {
|
||||
// Initialize scrollToEndEnabled based on config
|
||||
scrollToEndEnabled = cfg.AutoScrollEnabled
|
||||
go updateModelLists()
|
||||
go chatWatcher(ctx)
|
||||
}
|
||||
|
||||
// triggerPrivateMessageResponses checks if a message was sent privately to specific characters
|
||||
// and triggers those non-user characters to respond
|
||||
func triggerPrivateMessageResponses(msg models.RoleMsg, tv *tview.TextView) {
|
||||
func triggerPrivateMessageResponses(msg models.RoleMsg) {
|
||||
if cfg == nil || !cfg.CharSpecificContextEnabled {
|
||||
return
|
||||
}
|
||||
@@ -1237,6 +1275,11 @@ func triggerPrivateMessageResponses(msg models.RoleMsg, tv *tview.TextView) {
|
||||
// that indicates it's their turn
|
||||
triggerMsg := recipient + ":\n"
|
||||
// Call chatRound with the trigger message to make the recipient respond
|
||||
chatRound(triggerMsg, recipient, tv, false, false)
|
||||
// chatRound(triggerMsg, recipient, tv, false, false)
|
||||
crr := &models.ChatRoundReq{
|
||||
UserMsg: triggerMsg,
|
||||
Role: recipient,
|
||||
}
|
||||
chatRoundChan <- crr
|
||||
}
|
||||
}
|
||||
|
||||
13
helpfuncs.go
13
helpfuncs.go
@@ -279,3 +279,16 @@ func listChatRoles() []string {
|
||||
charset = append(charset, cbc...)
|
||||
return charset
|
||||
}
|
||||
|
||||
func deepseekModelValidator() error {
|
||||
if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
|
||||
if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
|
||||
if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
|
||||
logger.Warn("failed ot notify user", "error", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
6
llm.go
6
llm.go
@@ -363,6 +363,9 @@ func (ds DeepSeekerCompletion) GetToken() string {
|
||||
|
||||
func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader, error) {
|
||||
logger.Debug("formmsg deepseekercompletion", "link", cfg.CurrentAPI)
|
||||
if err := deepseekModelValidator(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg != "" { // otherwise let the bot to continue
|
||||
newMsg := models.RoleMsg{Role: role, Content: msg}
|
||||
newMsg = processMessageTag(newMsg)
|
||||
@@ -445,6 +448,9 @@ func (ds DeepSeekerChat) GetToken() string {
|
||||
|
||||
func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, error) {
|
||||
logger.Debug("formmsg deepseekerchat", "link", cfg.CurrentAPI)
|
||||
if err := deepseekModelValidator(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg != "" { // otherwise let the bot continue
|
||||
newMsg := models.RoleMsg{Role: role, Content: msg}
|
||||
newMsg = processMessageTag(newMsg)
|
||||
|
||||
@@ -540,3 +540,10 @@ func (lcp *LCPModels) ListModels() []string {
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// ChatRoundReq describes one chat round for the watcher goroutine:
// what to send, who speaks, and whether the round regenerates or
// resumes the previous assistant message.
type ChatRoundReq struct {
	UserMsg string // message text; empty lets the bot continue on its own
	Role    string // speaker role attached to the message
	Regen   bool   // regenerate the last assistant response
	Resume  bool   // continue the last message without a new role header
}
|
||||
|
||||
9
tui.go
9
tui.go
@@ -873,7 +873,8 @@ func init() {
|
||||
// there is no case where user msg is regenerated
|
||||
// lastRole := chatBody.Messages[len(chatBody.Messages)-1].Role
|
||||
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys))
|
||||
go chatRound("", cfg.UserRole, textView, true, false)
|
||||
// go chatRound("", cfg.UserRole, textView, true, false)
|
||||
chatRoundChan <- &models.ChatRoundReq{Role: cfg.UserRole}
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyF3 && !botRespMode {
|
||||
@@ -1176,7 +1177,8 @@ func init() {
|
||||
// INFO: continue bot/text message
|
||||
// without new role
|
||||
lastRole := chatBody.Messages[len(chatBody.Messages)-1].Role
|
||||
go chatRound("", lastRole, textView, false, true)
|
||||
// go chatRound("", lastRole, textView, false, true)
|
||||
chatRoundChan <- &models.ChatRoundReq{Role: lastRole, Resume: true}
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyCtrlQ {
|
||||
@@ -1347,7 +1349,8 @@ func init() {
|
||||
}
|
||||
colorText()
|
||||
}
|
||||
go chatRound(msgText, persona, textView, false, false)
|
||||
// go chatRound(msgText, persona, textView, false, false)
|
||||
chatRoundChan <- &models.ChatRoundReq{Role: persona, UserMsg: msgText}
|
||||
// Also clear any image attachment after sending the message
|
||||
go func() {
|
||||
// Wait a short moment for the message to be processed, then clear the image attachment
|
||||
|
||||
Reference in New Issue
Block a user