Feat: tool chunk channel for OpenAI tool calls

This commit is contained in:
Grail Finder
2025-08-08 10:51:14 +03:00
parent 14558f98cd
commit 589dfdda3f
2 changed files with 44 additions and 27 deletions

24
bot.go
View File

@@ -34,6 +34,7 @@ var (
logLevel = new(slog.LevelVar) logLevel = new(slog.LevelVar)
activeChatName string activeChatName string
chunkChan = make(chan string, 10) chunkChan = make(chan string, 10)
openAIToolChan = make(chan string, 10)
streamDone = make(chan bool, 1) streamDone = make(chan bool, 1)
chatBody *models.ChatBody chatBody *models.ChatBody
store storage.FullRepo store storage.FullRepo
@@ -189,8 +190,8 @@ func sendMsgToLLM(body io.Reader) {
for { for {
var ( var (
answerText string answerText string
content string
stop bool stop bool
chunk *models.TextChunk
) )
counter++ counter++
// to stop from spiriling in infinity read of bad bytes that happens with poor connection // to stop from spiriling in infinity read of bad bytes that happens with poor connection
@@ -225,7 +226,7 @@ func sendMsgToLLM(body io.Reader) {
if bytes.Equal(line, []byte("ROUTER PROCESSING\n")) { if bytes.Equal(line, []byte("ROUTER PROCESSING\n")) {
continue continue
} }
content, stop, err = chunkParser.ParseChunk(line) chunk, err = chunkParser.ParseChunk(line)
if err != nil { if err != nil {
logger.Error("error parsing response body", "error", err, logger.Error("error parsing response body", "error", err,
"line", string(line), "url", cfg.CurrentAPI) "line", string(line), "url", cfg.CurrentAPI)
@@ -239,18 +240,19 @@ func sendMsgToLLM(body io.Reader) {
break break
} }
if stop { if stop {
if content != "" { if chunk.Chunk != "" {
logger.Warn("text inside of finish llmchunk", "chunk", content, "counter", counter) logger.Warn("text inside of finish llmchunk", "chunk", chunk, "counter", counter)
} }
streamDone <- true streamDone <- true
break break
} }
if counter == 0 { if counter == 0 {
content = strings.TrimPrefix(content, " ") chunk.Chunk = strings.TrimPrefix(chunk.Chunk, " ")
} }
// bot sends way too many \n // bot sends way too many \n
answerText = strings.ReplaceAll(content, "\n\n", "\n") answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n")
chunkChan <- answerText chunkChan <- answerText
openAIToolChan <- chunk.ToolChunk
interrupt: interrupt:
if interruptResp { // read bytes, so it would not get into beginning of the next req if interruptResp { // read bytes, so it would not get into beginning of the next req
interruptResp = false interruptResp = false
@@ -362,6 +364,7 @@ func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
} }
} }
respText := strings.Builder{} respText := strings.Builder{}
toolResp := strings.Builder{}
out: out:
for { for {
select { select {
@@ -374,6 +377,10 @@ out:
// audioStream.TextChan <- chunk // audioStream.TextChan <- chunk
extra.TTSTextChan <- chunk extra.TTSTextChan <- chunk
} }
case toolChunk := <-openAIToolChan:
fmt.Fprint(tv, toolChunk)
toolResp.WriteString(toolChunk)
tv.ScrollToEnd()
case <-streamDone: case <-streamDone:
botRespMode = false botRespMode = false
if cfg.TTS_ENABLED { if cfg.TTS_ENABLED {
@@ -402,10 +409,11 @@ out:
if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil { if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil {
logger.Warn("failed to update storage", "error", err, "name", activeChatName) logger.Warn("failed to update storage", "error", err, "name", activeChatName)
} }
findCall(respText.String(), tv) // INFO: for completion only; openai has it's own tool struct
findCall(respText.String(), toolResp.String(), tv)
} }
func findCall(msg string, tv *tview.TextView) { func findCall(msg, toolCall string, tv *tview.TextView) {
fc := models.FuncCall{} fc := models.FuncCall{}
jsStr := toolCallRE.FindString(msg) jsStr := toolCallRE.FindString(msg)
if jsStr == "" { if jsStr == "" {

45
llm.go
View File

@@ -246,22 +246,27 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
return bytes.NewReader(data), nil return bytes.NewReader(data), nil
} }
func (ds DeepSeekerChat) ParseChunk(data []byte) (string, bool, error) { func (ds DeepSeekerChat) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.DSChatStreamResp{} llmchunk := models.DSChatStreamResp{}
if err := json.Unmarshal(data, &llmchunk); err != nil { if err := json.Unmarshal(data, &llmchunk); err != nil {
logger.Error("failed to decode", "error", err, "line", string(data)) logger.Error("failed to decode", "error", err, "line", string(data))
return "", false, err return nil, err
} }
resp := &models.TextChunk{}
if llmchunk.Choices[0].FinishReason != "" { if llmchunk.Choices[0].FinishReason != "" {
if llmchunk.Choices[0].Delta.Content != "" { if llmchunk.Choices[0].Delta.Content != "" {
logger.Error("text inside of finish llmchunk", "chunk", llmchunk) logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
} }
return llmchunk.Choices[0].Delta.Content, true, nil resp.Chunk = llmchunk.Choices[0].Delta.Content
} resp.Finished = true
} else {
if llmchunk.Choices[0].Delta.ReasoningContent != "" { if llmchunk.Choices[0].Delta.ReasoningContent != "" {
return llmchunk.Choices[0].Delta.ReasoningContent, false, nil resp.Chunk = llmchunk.Choices[0].Delta.ReasoningContent
} else {
resp.Chunk = llmchunk.Choices[0].Delta.Content
} }
return llmchunk.Choices[0].Delta.Content, false, nil }
return resp, nil
} }
func (ds DeepSeekerChat) GetToken() string { func (ds DeepSeekerChat) GetToken() string {
@@ -316,20 +321,22 @@ func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
} }
// openrouter // openrouter
func (or OpenRouterCompletion) ParseChunk(data []byte) (string, bool, error) { func (or OpenRouterCompletion) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.OpenRouterCompletionResp{} llmchunk := models.OpenRouterCompletionResp{}
if err := json.Unmarshal(data, &llmchunk); err != nil { if err := json.Unmarshal(data, &llmchunk); err != nil {
logger.Error("failed to decode", "error", err, "line", string(data)) logger.Error("failed to decode", "error", err, "line", string(data))
return "", false, err return nil, err
}
resp := &models.TextChunk{
Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Text,
} }
content := llmchunk.Choices[len(llmchunk.Choices)-1].Text
if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" { if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" {
if content != "" { if resp.Chunk != "" {
logger.Error("text inside of finish llmchunk", "chunk", llmchunk) logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
} }
return content, true, nil resp.Finished = true
} }
return content, false, nil return resp, nil
} }
func (or OpenRouterCompletion) GetToken() string { func (or OpenRouterCompletion) GetToken() string {
@@ -381,20 +388,22 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader
} }
// chat // chat
func (or OpenRouterChat) ParseChunk(data []byte) (string, bool, error) { func (or OpenRouterChat) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.OpenRouterChatResp{} llmchunk := models.OpenRouterChatResp{}
if err := json.Unmarshal(data, &llmchunk); err != nil { if err := json.Unmarshal(data, &llmchunk); err != nil {
logger.Error("failed to decode", "error", err, "line", string(data)) logger.Error("failed to decode", "error", err, "line", string(data))
return "", false, err return nil, err
}
resp := &models.TextChunk{
Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content,
} }
content := llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content
if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" { if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" {
if content != "" { if resp.Chunk != "" {
logger.Error("text inside of finish llmchunk", "chunk", llmchunk) logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
} }
return content, true, nil resp.Finished = true
} }
return content, false, nil return resp, nil
} }
func (or OpenRouterChat) GetToken() string { func (or OpenRouterChat) GetToken() string {