Enha: /chat /completions tool calls to live in peace
This commit is contained in:
44
bot.go
44
bot.go
@@ -17,7 +17,6 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -44,6 +43,7 @@ var (
|
||||
interruptResp = false
|
||||
ragger *rag.RAG
|
||||
chunkParser ChunkParser
|
||||
lastToolCall *models.FuncCall
|
||||
//nolint:unused // TTS_ENABLED conditionally uses this
|
||||
orator extra.Orator
|
||||
asr extra.STT
|
||||
@@ -171,7 +171,7 @@ func sendMsgToLLM(body io.Reader) {
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken())
|
||||
req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
|
||||
// req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
// nolint
|
||||
// resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
|
||||
@@ -253,6 +253,9 @@ func sendMsgToLLM(body io.Reader) {
|
||||
answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n")
|
||||
chunkChan <- answerText
|
||||
openAIToolChan <- chunk.ToolChunk
|
||||
if chunk.FuncName != "" {
|
||||
lastToolCall.Name = chunk.FuncName
|
||||
}
|
||||
interrupt:
|
||||
if interruptResp { // read bytes, so it would not get into beginning of the next req
|
||||
interruptResp = false
|
||||
@@ -409,22 +412,32 @@ out:
|
||||
if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil {
|
||||
logger.Warn("failed to update storage", "error", err, "name", activeChatName)
|
||||
}
|
||||
// INFO: for completion only; openai has it's own tool struct
|
||||
findCall(respText.String(), toolResp.String(), tv)
|
||||
}
|
||||
|
||||
func findCall(msg, toolCall string, tv *tview.TextView) {
|
||||
fc := models.FuncCall{}
|
||||
jsStr := toolCallRE.FindString(msg)
|
||||
if jsStr == "" {
|
||||
return
|
||||
}
|
||||
prefix := "__tool_call__\n"
|
||||
suffix := "\n__tool_call__"
|
||||
jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
|
||||
if err := json.Unmarshal([]byte(jsStr), &fc); err != nil {
|
||||
logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr)
|
||||
return
|
||||
fc := &models.FuncCall{}
|
||||
if toolCall != "" {
|
||||
openAIToolMap := make(map[string]string)
|
||||
// respect tool call
|
||||
if err := json.Unmarshal([]byte(toolCall), &openAIToolMap); err != nil {
|
||||
logger.Error("failed to unmarshal openai tool call", "call", toolCall, "error", err)
|
||||
return
|
||||
}
|
||||
lastToolCall.Args = openAIToolMap
|
||||
fc = lastToolCall
|
||||
} else {
|
||||
jsStr := toolCallRE.FindString(msg)
|
||||
if jsStr == "" {
|
||||
return
|
||||
}
|
||||
prefix := "__tool_call__\n"
|
||||
suffix := "\n__tool_call__"
|
||||
jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
|
||||
if err := json.Unmarshal([]byte(jsStr), &fc); err != nil {
|
||||
logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr)
|
||||
return
|
||||
}
|
||||
}
|
||||
// call a func
|
||||
f, ok := fnMap[fc.Name]
|
||||
@@ -433,7 +446,7 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
|
||||
chatRound(m, cfg.ToolRole, tv, false, false)
|
||||
return
|
||||
}
|
||||
resp := f(fc.Args...)
|
||||
resp := f(fc.Args)
|
||||
toolMsg := fmt.Sprintf("tool response: %+v", string(resp))
|
||||
chatRound(toolMsg, cfg.ToolRole, tv, false, false)
|
||||
}
|
||||
@@ -550,6 +563,7 @@ func init() {
|
||||
logger.Error("failed to load chat", "error", err)
|
||||
return
|
||||
}
|
||||
lastToolCall = &models.FuncCall{}
|
||||
lastChat := loadOldChatOrGetNew()
|
||||
chatBody = &models.ChatBody{
|
||||
Model: "modelname",
|
||||
|
||||
Reference in New Issue
Block a user