Enha: botresp, toolresp to atomic

This commit is contained in:
Grail Finder
2026-03-08 07:13:27 +03:00
parent 23cb8f2578
commit c200c9328c
3 changed files with 25 additions and 23 deletions

28
bot.go
View File

@@ -22,6 +22,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@@ -40,7 +41,7 @@ var (
store storage.FullRepo store storage.FullRepo
defaultFirstMsg = "Hello! What can I do for you?" defaultFirstMsg = "Hello! What can I do for you?"
defaultStarter = []models.RoleMsg{} defaultStarter = []models.RoleMsg{}
interruptResp = false interruptResp atomic.Bool
ragger *rag.RAG ragger *rag.RAG
chunkParser ChunkParser chunkParser ChunkParser
lastToolCall *models.FuncCall lastToolCall *models.FuncCall
@@ -643,7 +644,7 @@ func sendMsgToLLM(body io.Reader) {
// continue // continue
} }
if len(line) <= 1 { if len(line) <= 1 {
if interruptResp { if interruptResp.Load() {
goto interrupt // get unstuck from bad connection goto interrupt // get unstuck from bad connection
} }
continue // skip \n continue // skip \n
@@ -736,8 +737,7 @@ func sendMsgToLLM(body io.Reader) {
lastToolCall.ID = chunk.ToolID lastToolCall.ID = chunk.ToolID
} }
interrupt: interrupt:
if interruptResp { // read bytes, so it would not get into beginning of the next req if interruptResp.Load() { // read bytes, so it would not get into beginning of the next req
// interruptResp = false
logger.Info("interrupted bot response", "chunk_counter", counter) logger.Info("interrupted bot response", "chunk_counter", counter)
streamDone <- true streamDone <- true
break break
@@ -770,14 +770,14 @@ func showSpinner() {
if cfg.WriteNextMsgAsCompletionAgent != "" { if cfg.WriteNextMsgAsCompletionAgent != "" {
botPersona = cfg.WriteNextMsgAsCompletionAgent botPersona = cfg.WriteNextMsgAsCompletionAgent
} }
for botRespMode || toolRunningMode { for botRespMode.Load() || toolRunningMode.Load() {
time.Sleep(400 * time.Millisecond) time.Sleep(400 * time.Millisecond)
spin := i % len(spinners) spin := i % len(spinners)
app.QueueUpdateDraw(func() { app.QueueUpdateDraw(func() {
switch { switch {
case toolRunningMode: case toolRunningMode.Load():
textArea.SetTitle(spinners[spin] + " tool") textArea.SetTitle(spinners[spin] + " tool")
case botRespMode: case botRespMode.Load():
textArea.SetTitle(spinners[spin] + " " + botPersona + " (F6 to interrupt)") textArea.SetTitle(spinners[spin] + " " + botPersona + " (F6 to interrupt)")
default: default:
textArea.SetTitle(spinners[spin] + " input") textArea.SetTitle(spinners[spin] + " input")
@@ -791,8 +791,8 @@ func showSpinner() {
} }
func chatRound(r *models.ChatRoundReq) error { func chatRound(r *models.ChatRoundReq) error {
interruptResp = false interruptResp.Store(false)
botRespMode = true botRespMode.Store(true)
go showSpinner() go showSpinner()
updateStatusLine() updateStatusLine()
botPersona := cfg.AssistantRole botPersona := cfg.AssistantRole
@@ -800,7 +800,7 @@ func chatRound(r *models.ChatRoundReq) error {
botPersona = cfg.WriteNextMsgAsCompletionAgent botPersona = cfg.WriteNextMsgAsCompletionAgent
} }
defer func() { defer func() {
botRespMode = false botRespMode.Store(false)
ClearImageAttachment() ClearImageAttachment()
}() }()
// check that there is a model set to use if is not local // check that there is a model set to use if is not local
@@ -928,7 +928,7 @@ out:
} }
lastRespStats = nil lastRespStats = nil
} }
botRespMode = false botRespMode.Store(false)
if r.Resume { if r.Resume {
chatBody.Messages[len(chatBody.Messages)-1].Content += respText.String() chatBody.Messages[len(chatBody.Messages)-1].Content += respText.String()
updatedMsg := chatBody.Messages[len(chatBody.Messages)-1] updatedMsg := chatBody.Messages[len(chatBody.Messages)-1]
@@ -957,7 +957,7 @@ out:
} }
// Strip think blocks before parsing for tool calls // Strip think blocks before parsing for tool calls
respTextNoThink := thinkBlockRE.ReplaceAllString(respText.String(), "") respTextNoThink := thinkBlockRE.ReplaceAllString(respText.String(), "")
if interruptResp { if interruptResp.Load() {
return nil return nil
} }
if findCall(respTextNoThink, toolResp.String()) { if findCall(respTextNoThink, toolResp.String()) {
@@ -1192,9 +1192,9 @@ func findCall(msg, toolCall string) bool {
} }
// Show tool call progress indicator before execution // Show tool call progress indicator before execution
fmt.Fprintf(textView, "\n[yellow::i][tool: %s...][-:-:-]", fc.Name) fmt.Fprintf(textView, "\n[yellow::i][tool: %s...][-:-:-]", fc.Name)
toolRunningMode = true toolRunningMode.Store(true)
resp := callToolWithAgent(fc.Name, fc.Args) resp := callToolWithAgent(fc.Name, fc.Args)
toolRunningMode = false toolRunningMode.Store(false)
toolMsg := string(resp) toolMsg := string(resp)
logger.Info("llm used a tool call", "tool_name", fc.Name, "too_args", fc.Args, "id", fc.ID, "tool_resp", toolMsg) logger.Info("llm used a tool call", "tool_name", fc.Name, "too_args", fc.Args, "id", fc.ID, "tool_resp", toolMsg)
// Create tool response message with the proper tool_call_id // Create tool response message with the proper tool_call_id

View File

@@ -1,13 +1,15 @@
package main package main
import ( import (
"sync/atomic"
"github.com/rivo/tview" "github.com/rivo/tview"
) )
var ( var (
boolColors = map[bool]string{true: "green", false: "red"} boolColors = map[bool]string{true: "green", false: "red"}
botRespMode = false botRespMode atomic.Bool
toolRunningMode = false toolRunningMode atomic.Bool
editMode = false editMode = false
roleEditMode = false roleEditMode = false
injectRole = true injectRole = true

14
tui.go
View File

@@ -731,7 +731,7 @@ func initTUI() {
updateStatusLine() updateStatusLine()
return nil return nil
} }
if event.Key() == tcell.KeyF2 && !botRespMode { if event.Key() == tcell.KeyF2 && !botRespMode.Load() {
// regen last msg // regen last msg
if len(chatBody.Messages) == 0 { if len(chatBody.Messages) == 0 {
showToast("info", "no messages to regenerate") showToast("info", "no messages to regenerate")
@@ -748,7 +748,7 @@ func initTUI() {
chatRoundChan <- &models.ChatRoundReq{Role: cfg.UserRole, Regen: true} chatRoundChan <- &models.ChatRoundReq{Role: cfg.UserRole, Regen: true}
return nil return nil
} }
if event.Key() == tcell.KeyF3 && !botRespMode { if event.Key() == tcell.KeyF3 && !botRespMode.Load() {
// delete last msg // delete last msg
// check textarea text; if it ends with bot icon delete only icon: // check textarea text; if it ends with bot icon delete only icon:
text := textView.GetText(true) text := textView.GetText(true)
@@ -804,9 +804,9 @@ func initTUI() {
return nil return nil
} }
if event.Key() == tcell.KeyF6 { if event.Key() == tcell.KeyF6 {
interruptResp = true interruptResp.Store(true)
botRespMode = false botRespMode.Store(false)
toolRunningMode = false toolRunningMode.Store(false)
return nil return nil
} }
if event.Key() == tcell.KeyF7 { if event.Key() == tcell.KeyF7 {
@@ -1101,7 +1101,7 @@ func initTUI() {
return nil return nil
} }
// cannot send msg in editMode or botRespMode // cannot send msg in editMode or botRespMode
if event.Key() == tcell.KeyEscape && !editMode && !botRespMode { if event.Key() == tcell.KeyEscape && !editMode && !botRespMode.Load() {
if shellMode { if shellMode {
cmdText := shellInput.GetText() cmdText := shellInput.GetText()
if cmdText != "" { if cmdText != "" {
@@ -1167,7 +1167,7 @@ func initTUI() {
app.SetFocus(focusSwitcher[currentF]) app.SetFocus(focusSwitcher[currentF])
return nil return nil
} }
if isASCII(string(event.Rune())) && !botRespMode { if isASCII(string(event.Rune())) && !botRespMode.Load() {
return event return event
} }
return event return event