Enha: agent request builder

This commit is contained in:
Grail Finder
2025-12-19 15:39:55 +03:00
parent a875abcf19
commit f779f03974
5 changed files with 431 additions and 31 deletions

View File

@@ -33,3 +33,13 @@ func RegisterB(toolName string, a AgenterB) {
func RegisterA(toolNames []string, a AgenterA) { func RegisterA(toolNames []string, a AgenterA) {
RegistryA[a] = toolNames RegistryA[a] = toolNames
} }
// Get returns the agent registered for the given tool name, or nil if none.
func Get(toolName string) AgenterB {
return RegistryB[toolName]
}
// Register is a convenience wrapper for RegisterB.
func Register(toolName string, a AgenterB) {
RegisterB(toolName, a)
}

View File

@@ -3,15 +3,32 @@ package agent
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"gf-lt/config" "gf-lt/config"
"gf-lt/models" "gf-lt/models"
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"strings"
) )
var httpClient = &http.Client{} var httpClient = &http.Client{}
var defaultProps = map[string]float32{
"temperature": 0.8,
"dry_multiplier": 0.0,
"min_p": 0.05,
"n_predict": -1.0,
}
func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool) {
isCompletion = strings.Contains(api, "/completion") && !strings.Contains(api, "/chat/completions")
isChat = strings.Contains(api, "/chat/completions")
isDeepSeek = strings.Contains(api, "deepseek.com")
isOpenRouter = strings.Contains(api, "openrouter.ai")
return
}
type AgentClient struct { type AgentClient struct {
cfg *config.Config cfg *config.Config
getToken func() string getToken func() string
@@ -31,38 +48,185 @@ func (ag *AgentClient) Log() *slog.Logger {
} }
func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) { func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) {
agentConvo := []models.RoleMsg{ b, err := ag.buildRequest(sysprompt, msg)
{Role: "system", Content: sysprompt},
{Role: "user", Content: msg},
}
agentChat := &models.ChatBody{
Model: ag.cfg.CurrentModel,
Stream: true,
Messages: agentConvo,
}
b, err := json.Marshal(agentChat)
if err != nil { if err != nil {
ag.log.Error("failed to form agent msg", "error", err)
return nil, err return nil, err
} }
return bytes.NewReader(b), nil return bytes.NewReader(b), nil
} }
// buildRequest creates the appropriate LLM request based on the current API endpoint.
func (ag *AgentClient) buildRequest(sysprompt, msg string) ([]byte, error) {
api := ag.cfg.CurrentAPI
model := ag.cfg.CurrentModel
messages := []models.RoleMsg{
{Role: "system", Content: sysprompt},
{Role: "user", Content: msg},
}
// Determine API type
isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(api)
ag.log.Debug("agent building request", "api", api, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter)
// Build prompt for completion endpoints
if isCompletion {
var sb strings.Builder
for _, m := range messages {
sb.WriteString(m.ToPrompt())
sb.WriteString("\n")
}
prompt := strings.TrimSpace(sb.String())
if isDeepSeek {
// DeepSeek completion
req := models.NewDSCompletionReq(prompt, model, defaultProps["temperature"], []string{})
req.Stream = false // Agents don't need streaming
return json.Marshal(req)
} else if isOpenRouter {
// OpenRouter completion
req := models.NewOpenRouterCompletionReq(model, prompt, defaultProps, []string{})
req.Stream = false // Agents don't need streaming
return json.Marshal(req)
} else {
// Assume llama.cpp completion
req := models.NewLCPReq(prompt, model, nil, defaultProps, []string{})
req.Stream = false // Agents don't need streaming
return json.Marshal(req)
}
}
// Chat completions endpoints
if isChat || !isCompletion {
chatBody := &models.ChatBody{
Model: model,
Stream: false, // Agents don't need streaming
Messages: messages,
}
if isDeepSeek {
// DeepSeek chat
req := models.NewDSChatReq(*chatBody)
return json.Marshal(req)
} else if isOpenRouter {
// OpenRouter chat
req := models.NewOpenRouterChatReq(*chatBody, defaultProps)
return json.Marshal(req)
} else {
// Assume llama.cpp chat (OpenAI format)
req := models.OpenAIReq{
ChatBody: chatBody,
Tools: nil,
}
return json.Marshal(req)
}
}
// Fallback (should not reach here)
ag.log.Warn("unknown API, using default chat completions format", "api", api)
chatBody := &models.ChatBody{
Model: model,
Stream: false, // Agents don't need streaming
Messages: messages,
}
return json.Marshal(chatBody)
}
func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) { func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, body) // Read the body for debugging (but we need to recreate it for the request)
bodyBytes, err := io.ReadAll(body)
if err != nil { if err != nil {
ag.log.Error("llamacpp api", "error", err) ag.log.Error("failed to read request body", "error", err)
return nil, err
}
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
if err != nil {
ag.log.Error("failed to create request", "error", err)
return nil, err return nil, err
} }
req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer "+ag.getToken()) req.Header.Add("Authorization", "Bearer "+ag.getToken())
req.Header.Set("Accept-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")
ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
resp, err := httpClient.Do(req) resp, err := httpClient.Do(req)
if err != nil { if err != nil {
ag.log.Error("llamacpp api", "error", err) ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
return io.ReadAll(resp.Body)
responseBytes, err := io.ReadAll(resp.Body)
if err != nil {
ag.log.Error("failed to read response", "error", err)
return nil, err
}
if resp.StatusCode >= 400 {
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
}
// Parse response and extract text content
text, err := extractTextFromResponse(responseBytes)
if err != nil {
ag.log.Error("failed to extract text from response", "error", err, "response_preview", string(responseBytes[:min(len(responseBytes), 500)]))
// Return raw response as fallback
return responseBytes, nil
}
return []byte(text), nil
}
// extractTextFromResponse parses common LLM response formats and extracts the text content.
func extractTextFromResponse(data []byte) (string, error) {
// Try to parse as generic JSON first
var genericResp map[string]interface{}
if err := json.Unmarshal(data, &genericResp); err != nil {
// Not JSON, return as string
return string(data), nil
}
// Check for OpenAI chat completion format
if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 {
if firstChoice, ok := choices[0].(map[string]interface{}); ok {
// Chat completion: choices[0].message.content
if message, ok := firstChoice["message"].(map[string]interface{}); ok {
if content, ok := message["content"].(string); ok {
return content, nil
}
}
// Completion: choices[0].text
if text, ok := firstChoice["text"].(string); ok {
return text, nil
}
// Delta format for streaming (should not happen with stream: false)
if delta, ok := firstChoice["delta"].(map[string]interface{}); ok {
if content, ok := delta["content"].(string); ok {
return content, nil
}
}
}
}
// Check for llama.cpp completion format
if content, ok := genericResp["content"].(string); ok {
return content, nil
}
// Unknown format, return pretty-printed JSON
prettyJSON, err := json.MarshalIndent(genericResp, "", " ")
if err != nil {
return string(data), nil
}
return string(prettyJSON), nil
}
func min(a, b int) int {
if a < b {
return a
}
return b
} }

72
bot.go
View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv"
"gf-lt/config" "gf-lt/config"
"gf-lt/extra" "gf-lt/extra"
"gf-lt/models" "gf-lt/models"
@@ -659,14 +660,75 @@ func cleanChatBody() {
} }
} }
// convertJSONToMapStringString unmarshals JSON into map[string]interface{} and converts all values to strings.
func convertJSONToMapStringString(jsonStr string) (map[string]string, error) {
var raw map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
return nil, err
}
result := make(map[string]string, len(raw))
for k, v := range raw {
switch val := v.(type) {
case string:
result[k] = val
case float64:
result[k] = strconv.FormatFloat(val, 'f', -1, 64)
case int, int64, int32:
// json.Unmarshal converts numbers to float64, but handle other integer types if they appear
result[k] = fmt.Sprintf("%v", val)
case bool:
result[k] = strconv.FormatBool(val)
case nil:
result[k] = ""
default:
result[k] = fmt.Sprintf("%v", val)
}
}
return result, nil
}
// unmarshalFuncCall unmarshals a JSON tool call, converting numeric arguments to strings.
func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) {
type tempFuncCall struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
Args map[string]interface{} `json:"args"`
}
var temp tempFuncCall
if err := json.Unmarshal([]byte(jsonStr), &temp); err != nil {
return nil, err
}
fc := &models.FuncCall{
ID: temp.ID,
Name: temp.Name,
Args: make(map[string]string, len(temp.Args)),
}
for k, v := range temp.Args {
switch val := v.(type) {
case string:
fc.Args[k] = val
case float64:
fc.Args[k] = strconv.FormatFloat(val, 'f', -1, 64)
case int, int64, int32:
fc.Args[k] = fmt.Sprintf("%v", val)
case bool:
fc.Args[k] = strconv.FormatBool(val)
case nil:
fc.Args[k] = ""
default:
fc.Args[k] = fmt.Sprintf("%v", val)
}
}
return fc, nil
}
func findCall(msg, toolCall string, tv *tview.TextView) { func findCall(msg, toolCall string, tv *tview.TextView) {
fc := &models.FuncCall{} fc := &models.FuncCall{}
if toolCall != "" { if toolCall != "" {
// HTML-decode the tool call string to handle encoded characters like &lt; -> <= // HTML-decode the tool call string to handle encoded characters like &lt; -> <=
decodedToolCall := html.UnescapeString(toolCall) decodedToolCall := html.UnescapeString(toolCall)
openAIToolMap := make(map[string]string) openAIToolMap, err := convertJSONToMapStringString(decodedToolCall)
// respect tool call if err != nil {
if err := json.Unmarshal([]byte(decodedToolCall), &openAIToolMap); err != nil {
logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err) logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err)
// Send error response to LLM so it can retry or handle the error // Send error response to LLM so it can retry or handle the error
toolResponseMsg := models.RoleMsg{ toolResponseMsg := models.RoleMsg{
@@ -700,7 +762,9 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix) jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
// HTML-decode the JSON string to handle encoded characters like &lt; -> <= // HTML-decode the JSON string to handle encoded characters like &lt; -> <=
decodedJsStr := html.UnescapeString(jsStr) decodedJsStr := html.UnescapeString(jsStr)
if err := json.Unmarshal([]byte(decodedJsStr), &fc); err != nil { var err error
fc, err = unmarshalFuncCall(decodedJsStr)
if err != nil {
logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr) logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr)
// Send error response to LLM so it can retry or handle the error // Send error response to LLM so it can retry or handle the error
toolResponseMsg := models.RoleMsg{ toolResponseMsg := models.RoleMsg{

View File

@@ -152,4 +152,138 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
} }
}) })
} }
}
func TestUnmarshalFuncCall(t *testing.T) {
tests := []struct {
name string
jsonStr string
want *models.FuncCall
wantErr bool
}{
{
name: "simple websearch with numeric limit",
jsonStr: `{"name": "websearch", "args": {"query": "current weather in London", "limit": 3}}`,
want: &models.FuncCall{
Name: "websearch",
Args: map[string]string{"query": "current weather in London", "limit": "3"},
},
wantErr: false,
},
{
name: "string limit",
jsonStr: `{"name": "websearch", "args": {"query": "test", "limit": "5"}}`,
want: &models.FuncCall{
Name: "websearch",
Args: map[string]string{"query": "test", "limit": "5"},
},
wantErr: false,
},
{
name: "boolean arg",
jsonStr: `{"name": "test", "args": {"flag": true}}`,
want: &models.FuncCall{
Name: "test",
Args: map[string]string{"flag": "true"},
},
wantErr: false,
},
{
name: "null arg",
jsonStr: `{"name": "test", "args": {"opt": null}}`,
want: &models.FuncCall{
Name: "test",
Args: map[string]string{"opt": ""},
},
wantErr: false,
},
{
name: "float arg",
jsonStr: `{"name": "test", "args": {"ratio": 0.5}}`,
want: &models.FuncCall{
Name: "test",
Args: map[string]string{"ratio": "0.5"},
},
wantErr: false,
},
{
name: "invalid JSON",
jsonStr: `{invalid}`,
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := unmarshalFuncCall(tt.jsonStr)
if (err != nil) != tt.wantErr {
t.Errorf("unmarshalFuncCall() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if got.Name != tt.want.Name {
t.Errorf("unmarshalFuncCall() name = %v, want %v", got.Name, tt.want.Name)
}
if len(got.Args) != len(tt.want.Args) {
t.Errorf("unmarshalFuncCall() args length = %v, want %v", len(got.Args), len(tt.want.Args))
}
for k, v := range tt.want.Args {
if got.Args[k] != v {
t.Errorf("unmarshalFuncCall() args[%v] = %v, want %v", k, got.Args[k], v)
}
}
})
}
}
func TestConvertJSONToMapStringString(t *testing.T) {
tests := []struct {
name string
jsonStr string
want map[string]string
wantErr bool
}{
{
name: "simple map",
jsonStr: `{"query": "weather", "limit": 5}`,
want: map[string]string{"query": "weather", "limit": "5"},
wantErr: false,
},
{
name: "boolean and null",
jsonStr: `{"flag": true, "opt": null}`,
want: map[string]string{"flag": "true", "opt": ""},
wantErr: false,
},
{
name: "invalid JSON",
jsonStr: `{invalid`,
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := convertJSONToMapStringString(tt.jsonStr)
if (err != nil) != tt.wantErr {
t.Errorf("convertJSONToMapStringString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if len(got) != len(tt.want) {
t.Errorf("convertJSONToMapStringString() length = %v, want %v", len(got), len(tt.want))
}
for k, v := range tt.want {
if got[k] != v {
t.Errorf("convertJSONToMapStringString()[%v] = %v, want %v", k, got[k], v)
}
}
})
}
} }

View File

@@ -13,6 +13,7 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
@@ -126,7 +127,9 @@ under the topic: Adam's number is stored:
</example_response> </example_response>
After that you are free to respond to the user. After that you are free to respond to the user.
` `
basicCard = &models.CharCard{ webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.`
readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.`
basicCard = &models.CharCard{
SysPrompt: basicSysMsg, SysPrompt: basicSysMsg,
FirstMsg: defaultFirstMsg, FirstMsg: defaultFirstMsg,
Role: "", Role: "",
@@ -141,8 +144,43 @@ After that you are free to respond to the user.
// sysMap = map[string]string{"basic_sys": basicSysMsg, "tool_sys": toolSysMsg} // sysMap = map[string]string{"basic_sys": basicSysMsg, "tool_sys": toolSysMsg}
sysMap = map[string]*models.CharCard{"basic_sys": basicCard} sysMap = map[string]*models.CharCard{"basic_sys": basicCard}
sysLabels = []string{"basic_sys"} sysLabels = []string{"basic_sys"}
webAgentClient *agent.AgentClient
webAgentClientOnce sync.Once
webAgentsOnce sync.Once
) )
// getWebAgentClient returns a singleton AgentClient for web agents.
func getWebAgentClient() *agent.AgentClient {
webAgentClientOnce.Do(func() {
if cfg == nil {
panic("cfg not initialized")
}
if logger == nil {
panic("logger not initialized")
}
getToken := func() string {
if chunkParser == nil {
return ""
}
return chunkParser.GetToken()
}
webAgentClient = agent.NewAgentClient(cfg, *logger, getToken)
})
return webAgentClient
}
// registerWebAgents registers WebAgentB instances for websearch and read_url tools.
func registerWebAgents() {
webAgentsOnce.Do(func() {
client := getWebAgentClient()
// Register websearch agent
agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt))
// Register read_url agent
agent.Register("read_url", agent.NewWebAgentB(client, readURLSysPrompt))
})
}
// web search (depends on extra server) // web search (depends on extra server)
func websearch(args map[string]string) []byte { func websearch(args map[string]string) []byte {
// make http request return bytes // make http request return bytes
@@ -597,7 +635,6 @@ var globalTodoList = TodoList{
Items: []TodoItem{}, Items: []TodoItem{},
} }
// Todo Management Tools // Todo Management Tools
func todoCreate(args map[string]string) []byte { func todoCreate(args map[string]string) []byte {
task, ok := args["task"] task, ok := args["task"]
@@ -851,6 +888,7 @@ var fnMap = map[string]fnSig{
// callToolWithAgent calls the tool and applies any registered agent. // callToolWithAgent calls the tool and applies any registered agent.
func callToolWithAgent(name string, args map[string]string) []byte { func callToolWithAgent(name string, args map[string]string) []byte {
registerWebAgents()
f, ok := fnMap[name] f, ok := fnMap[name]
if !ok { if !ok {
return []byte(fmt.Sprintf("tool %s not found", name)) return []byte(fmt.Sprintf("tool %s not found", name))
@@ -862,16 +900,6 @@ func callToolWithAgent(name string, args map[string]string) []byte {
return raw return raw
} }
// registerDefaultAgents registers default agents for formatting.
func registerDefaultAgents() {
agent.Register("websearch", agent.DefaultFormatter("websearch"))
agent.Register("read_url", agent.DefaultFormatter("read_url"))
}
func init() {
registerDefaultAgents()
}
// openai style def // openai style def
var baseTools = []models.Tool{ var baseTools = []models.Tool{
// websearch // websearch