Enha: agent request builder
This commit is contained in:
@@ -33,3 +33,13 @@ func RegisterB(toolName string, a AgenterB) {
|
|||||||
func RegisterA(toolNames []string, a AgenterA) {
|
func RegisterA(toolNames []string, a AgenterA) {
|
||||||
RegistryA[a] = toolNames
|
RegistryA[a] = toolNames
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get returns the agent registered for the given tool name, or nil if none.
|
||||||
|
func Get(toolName string) AgenterB {
|
||||||
|
return RegistryB[toolName]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register is a convenience wrapper for RegisterB.
|
||||||
|
func Register(toolName string, a AgenterB) {
|
||||||
|
RegisterB(toolName, a)
|
||||||
|
}
|
||||||
|
|||||||
194
agent/request.go
194
agent/request.go
@@ -3,15 +3,32 @@ package agent
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"gf-lt/config"
|
"gf-lt/config"
|
||||||
"gf-lt/models"
|
"gf-lt/models"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var httpClient = &http.Client{}
|
var httpClient = &http.Client{}
|
||||||
|
|
||||||
|
var defaultProps = map[string]float32{
|
||||||
|
"temperature": 0.8,
|
||||||
|
"dry_multiplier": 0.0,
|
||||||
|
"min_p": 0.05,
|
||||||
|
"n_predict": -1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool) {
|
||||||
|
isCompletion = strings.Contains(api, "/completion") && !strings.Contains(api, "/chat/completions")
|
||||||
|
isChat = strings.Contains(api, "/chat/completions")
|
||||||
|
isDeepSeek = strings.Contains(api, "deepseek.com")
|
||||||
|
isOpenRouter = strings.Contains(api, "openrouter.ai")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
type AgentClient struct {
|
type AgentClient struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
getToken func() string
|
getToken func() string
|
||||||
@@ -31,38 +48,185 @@ func (ag *AgentClient) Log() *slog.Logger {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) {
|
func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) {
|
||||||
agentConvo := []models.RoleMsg{
|
b, err := ag.buildRequest(sysprompt, msg)
|
||||||
{Role: "system", Content: sysprompt},
|
|
||||||
{Role: "user", Content: msg},
|
|
||||||
}
|
|
||||||
agentChat := &models.ChatBody{
|
|
||||||
Model: ag.cfg.CurrentModel,
|
|
||||||
Stream: true,
|
|
||||||
Messages: agentConvo,
|
|
||||||
}
|
|
||||||
b, err := json.Marshal(agentChat)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("failed to form agent msg", "error", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return bytes.NewReader(b), nil
|
return bytes.NewReader(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildRequest creates the appropriate LLM request based on the current API endpoint.
|
||||||
|
func (ag *AgentClient) buildRequest(sysprompt, msg string) ([]byte, error) {
|
||||||
|
api := ag.cfg.CurrentAPI
|
||||||
|
model := ag.cfg.CurrentModel
|
||||||
|
messages := []models.RoleMsg{
|
||||||
|
{Role: "system", Content: sysprompt},
|
||||||
|
{Role: "user", Content: msg},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine API type
|
||||||
|
isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(api)
|
||||||
|
ag.log.Debug("agent building request", "api", api, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter)
|
||||||
|
|
||||||
|
// Build prompt for completion endpoints
|
||||||
|
if isCompletion {
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, m := range messages {
|
||||||
|
sb.WriteString(m.ToPrompt())
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
prompt := strings.TrimSpace(sb.String())
|
||||||
|
|
||||||
|
if isDeepSeek {
|
||||||
|
// DeepSeek completion
|
||||||
|
req := models.NewDSCompletionReq(prompt, model, defaultProps["temperature"], []string{})
|
||||||
|
req.Stream = false // Agents don't need streaming
|
||||||
|
return json.Marshal(req)
|
||||||
|
} else if isOpenRouter {
|
||||||
|
// OpenRouter completion
|
||||||
|
req := models.NewOpenRouterCompletionReq(model, prompt, defaultProps, []string{})
|
||||||
|
req.Stream = false // Agents don't need streaming
|
||||||
|
return json.Marshal(req)
|
||||||
|
} else {
|
||||||
|
// Assume llama.cpp completion
|
||||||
|
req := models.NewLCPReq(prompt, model, nil, defaultProps, []string{})
|
||||||
|
req.Stream = false // Agents don't need streaming
|
||||||
|
return json.Marshal(req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat completions endpoints
|
||||||
|
if isChat || !isCompletion {
|
||||||
|
chatBody := &models.ChatBody{
|
||||||
|
Model: model,
|
||||||
|
Stream: false, // Agents don't need streaming
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isDeepSeek {
|
||||||
|
// DeepSeek chat
|
||||||
|
req := models.NewDSChatReq(*chatBody)
|
||||||
|
return json.Marshal(req)
|
||||||
|
} else if isOpenRouter {
|
||||||
|
// OpenRouter chat
|
||||||
|
req := models.NewOpenRouterChatReq(*chatBody, defaultProps)
|
||||||
|
return json.Marshal(req)
|
||||||
|
} else {
|
||||||
|
// Assume llama.cpp chat (OpenAI format)
|
||||||
|
req := models.OpenAIReq{
|
||||||
|
ChatBody: chatBody,
|
||||||
|
Tools: nil,
|
||||||
|
}
|
||||||
|
return json.Marshal(req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback (should not reach here)
|
||||||
|
ag.log.Warn("unknown API, using default chat completions format", "api", api)
|
||||||
|
chatBody := &models.ChatBody{
|
||||||
|
Model: model,
|
||||||
|
Stream: false, // Agents don't need streaming
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
return json.Marshal(chatBody)
|
||||||
|
}
|
||||||
|
|
||||||
func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
||||||
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, body)
|
// Read the body for debugging (but we need to recreate it for the request)
|
||||||
|
bodyBytes, err := io.ReadAll(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("llamacpp api", "error", err)
|
ag.log.Error("failed to read request body", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
|
||||||
|
if err != nil {
|
||||||
|
ag.log.Error("failed to create request", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Add("Accept", "application/json")
|
req.Header.Add("Accept", "application/json")
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Add("Authorization", "Bearer "+ag.getToken())
|
req.Header.Add("Authorization", "Bearer "+ag.getToken())
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
|
||||||
|
ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
resp, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("llamacpp api", "error", err)
|
ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
return io.ReadAll(resp.Body)
|
|
||||||
|
responseBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
ag.log.Error("failed to read response", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
|
||||||
|
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse response and extract text content
|
||||||
|
text, err := extractTextFromResponse(responseBytes)
|
||||||
|
if err != nil {
|
||||||
|
ag.log.Error("failed to extract text from response", "error", err, "response_preview", string(responseBytes[:min(len(responseBytes), 500)]))
|
||||||
|
// Return raw response as fallback
|
||||||
|
return responseBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(text), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractTextFromResponse parses common LLM response formats and extracts the text content.
|
||||||
|
func extractTextFromResponse(data []byte) (string, error) {
|
||||||
|
// Try to parse as generic JSON first
|
||||||
|
var genericResp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &genericResp); err != nil {
|
||||||
|
// Not JSON, return as string
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for OpenAI chat completion format
|
||||||
|
if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||||
|
if firstChoice, ok := choices[0].(map[string]interface{}); ok {
|
||||||
|
// Chat completion: choices[0].message.content
|
||||||
|
if message, ok := firstChoice["message"].(map[string]interface{}); ok {
|
||||||
|
if content, ok := message["content"].(string); ok {
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Completion: choices[0].text
|
||||||
|
if text, ok := firstChoice["text"].(string); ok {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
// Delta format for streaming (should not happen with stream: false)
|
||||||
|
if delta, ok := firstChoice["delta"].(map[string]interface{}); ok {
|
||||||
|
if content, ok := delta["content"].(string); ok {
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for llama.cpp completion format
|
||||||
|
if content, ok := genericResp["content"].(string); ok {
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown format, return pretty-printed JSON
|
||||||
|
prettyJSON, err := json.MarshalIndent(genericResp, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
return string(prettyJSON), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
}
|
}
|
||||||
|
|||||||
72
bot.go
72
bot.go
@@ -6,6 +6,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
"gf-lt/config"
|
"gf-lt/config"
|
||||||
"gf-lt/extra"
|
"gf-lt/extra"
|
||||||
"gf-lt/models"
|
"gf-lt/models"
|
||||||
@@ -659,14 +660,75 @@ func cleanChatBody() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convertJSONToMapStringString unmarshals JSON into map[string]interface{} and converts all values to strings.
|
||||||
|
func convertJSONToMapStringString(jsonStr string) (map[string]string, error) {
|
||||||
|
var raw map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := make(map[string]string, len(raw))
|
||||||
|
for k, v := range raw {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
result[k] = val
|
||||||
|
case float64:
|
||||||
|
result[k] = strconv.FormatFloat(val, 'f', -1, 64)
|
||||||
|
case int, int64, int32:
|
||||||
|
// json.Unmarshal converts numbers to float64, but handle other integer types if they appear
|
||||||
|
result[k] = fmt.Sprintf("%v", val)
|
||||||
|
case bool:
|
||||||
|
result[k] = strconv.FormatBool(val)
|
||||||
|
case nil:
|
||||||
|
result[k] = ""
|
||||||
|
default:
|
||||||
|
result[k] = fmt.Sprintf("%v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshalFuncCall unmarshals a JSON tool call, converting numeric arguments to strings.
|
||||||
|
func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) {
|
||||||
|
type tempFuncCall struct {
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Args map[string]interface{} `json:"args"`
|
||||||
|
}
|
||||||
|
var temp tempFuncCall
|
||||||
|
if err := json.Unmarshal([]byte(jsonStr), &temp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fc := &models.FuncCall{
|
||||||
|
ID: temp.ID,
|
||||||
|
Name: temp.Name,
|
||||||
|
Args: make(map[string]string, len(temp.Args)),
|
||||||
|
}
|
||||||
|
for k, v := range temp.Args {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
fc.Args[k] = val
|
||||||
|
case float64:
|
||||||
|
fc.Args[k] = strconv.FormatFloat(val, 'f', -1, 64)
|
||||||
|
case int, int64, int32:
|
||||||
|
fc.Args[k] = fmt.Sprintf("%v", val)
|
||||||
|
case bool:
|
||||||
|
fc.Args[k] = strconv.FormatBool(val)
|
||||||
|
case nil:
|
||||||
|
fc.Args[k] = ""
|
||||||
|
default:
|
||||||
|
fc.Args[k] = fmt.Sprintf("%v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fc, nil
|
||||||
|
}
|
||||||
|
|
||||||
func findCall(msg, toolCall string, tv *tview.TextView) {
|
func findCall(msg, toolCall string, tv *tview.TextView) {
|
||||||
fc := &models.FuncCall{}
|
fc := &models.FuncCall{}
|
||||||
if toolCall != "" {
|
if toolCall != "" {
|
||||||
// HTML-decode the tool call string to handle encoded characters like < -> <=
|
// HTML-decode the tool call string to handle encoded characters like < -> <=
|
||||||
decodedToolCall := html.UnescapeString(toolCall)
|
decodedToolCall := html.UnescapeString(toolCall)
|
||||||
openAIToolMap := make(map[string]string)
|
openAIToolMap, err := convertJSONToMapStringString(decodedToolCall)
|
||||||
// respect tool call
|
if err != nil {
|
||||||
if err := json.Unmarshal([]byte(decodedToolCall), &openAIToolMap); err != nil {
|
|
||||||
logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err)
|
logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err)
|
||||||
// Send error response to LLM so it can retry or handle the error
|
// Send error response to LLM so it can retry or handle the error
|
||||||
toolResponseMsg := models.RoleMsg{
|
toolResponseMsg := models.RoleMsg{
|
||||||
@@ -700,7 +762,9 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
|
|||||||
jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
|
jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
|
||||||
// HTML-decode the JSON string to handle encoded characters like < -> <=
|
// HTML-decode the JSON string to handle encoded characters like < -> <=
|
||||||
decodedJsStr := html.UnescapeString(jsStr)
|
decodedJsStr := html.UnescapeString(jsStr)
|
||||||
if err := json.Unmarshal([]byte(decodedJsStr), &fc); err != nil {
|
var err error
|
||||||
|
fc, err = unmarshalFuncCall(decodedJsStr)
|
||||||
|
if err != nil {
|
||||||
logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr)
|
logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr)
|
||||||
// Send error response to LLM so it can retry or handle the error
|
// Send error response to LLM so it can retry or handle the error
|
||||||
toolResponseMsg := models.RoleMsg{
|
toolResponseMsg := models.RoleMsg{
|
||||||
|
|||||||
134
bot_test.go
134
bot_test.go
@@ -152,4 +152,138 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalFuncCall(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
jsonStr string
|
||||||
|
want *models.FuncCall
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple websearch with numeric limit",
|
||||||
|
jsonStr: `{"name": "websearch", "args": {"query": "current weather in London", "limit": 3}}`,
|
||||||
|
want: &models.FuncCall{
|
||||||
|
Name: "websearch",
|
||||||
|
Args: map[string]string{"query": "current weather in London", "limit": "3"},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string limit",
|
||||||
|
jsonStr: `{"name": "websearch", "args": {"query": "test", "limit": "5"}}`,
|
||||||
|
want: &models.FuncCall{
|
||||||
|
Name: "websearch",
|
||||||
|
Args: map[string]string{"query": "test", "limit": "5"},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "boolean arg",
|
||||||
|
jsonStr: `{"name": "test", "args": {"flag": true}}`,
|
||||||
|
want: &models.FuncCall{
|
||||||
|
Name: "test",
|
||||||
|
Args: map[string]string{"flag": "true"},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null arg",
|
||||||
|
jsonStr: `{"name": "test", "args": {"opt": null}}`,
|
||||||
|
want: &models.FuncCall{
|
||||||
|
Name: "test",
|
||||||
|
Args: map[string]string{"opt": ""},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "float arg",
|
||||||
|
jsonStr: `{"name": "test", "args": {"ratio": 0.5}}`,
|
||||||
|
want: &models.FuncCall{
|
||||||
|
Name: "test",
|
||||||
|
Args: map[string]string{"ratio": "0.5"},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON",
|
||||||
|
jsonStr: `{invalid}`,
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := unmarshalFuncCall(tt.jsonStr)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("unmarshalFuncCall() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got.Name != tt.want.Name {
|
||||||
|
t.Errorf("unmarshalFuncCall() name = %v, want %v", got.Name, tt.want.Name)
|
||||||
|
}
|
||||||
|
if len(got.Args) != len(tt.want.Args) {
|
||||||
|
t.Errorf("unmarshalFuncCall() args length = %v, want %v", len(got.Args), len(tt.want.Args))
|
||||||
|
}
|
||||||
|
for k, v := range tt.want.Args {
|
||||||
|
if got.Args[k] != v {
|
||||||
|
t.Errorf("unmarshalFuncCall() args[%v] = %v, want %v", k, got.Args[k], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertJSONToMapStringString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
jsonStr string
|
||||||
|
want map[string]string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple map",
|
||||||
|
jsonStr: `{"query": "weather", "limit": 5}`,
|
||||||
|
want: map[string]string{"query": "weather", "limit": "5"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "boolean and null",
|
||||||
|
jsonStr: `{"flag": true, "opt": null}`,
|
||||||
|
want: map[string]string{"flag": "true", "opt": ""},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON",
|
||||||
|
jsonStr: `{invalid`,
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := convertJSONToMapStringString(tt.jsonStr)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("convertJSONToMapStringString() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(got) != len(tt.want) {
|
||||||
|
t.Errorf("convertJSONToMapStringString() length = %v, want %v", len(got), len(tt.want))
|
||||||
|
}
|
||||||
|
for k, v := range tt.want {
|
||||||
|
if got[k] != v {
|
||||||
|
t.Errorf("convertJSONToMapStringString()[%v] = %v, want %v", k, got[k], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
52
tools.go
52
tools.go
@@ -13,6 +13,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -126,7 +127,9 @@ under the topic: Adam's number is stored:
|
|||||||
</example_response>
|
</example_response>
|
||||||
After that you are free to respond to the user.
|
After that you are free to respond to the user.
|
||||||
`
|
`
|
||||||
basicCard = &models.CharCard{
|
webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.`
|
||||||
|
readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.`
|
||||||
|
basicCard = &models.CharCard{
|
||||||
SysPrompt: basicSysMsg,
|
SysPrompt: basicSysMsg,
|
||||||
FirstMsg: defaultFirstMsg,
|
FirstMsg: defaultFirstMsg,
|
||||||
Role: "",
|
Role: "",
|
||||||
@@ -141,8 +144,43 @@ After that you are free to respond to the user.
|
|||||||
// sysMap = map[string]string{"basic_sys": basicSysMsg, "tool_sys": toolSysMsg}
|
// sysMap = map[string]string{"basic_sys": basicSysMsg, "tool_sys": toolSysMsg}
|
||||||
sysMap = map[string]*models.CharCard{"basic_sys": basicCard}
|
sysMap = map[string]*models.CharCard{"basic_sys": basicCard}
|
||||||
sysLabels = []string{"basic_sys"}
|
sysLabels = []string{"basic_sys"}
|
||||||
|
|
||||||
|
webAgentClient *agent.AgentClient
|
||||||
|
webAgentClientOnce sync.Once
|
||||||
|
webAgentsOnce sync.Once
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// getWebAgentClient returns a singleton AgentClient for web agents.
|
||||||
|
func getWebAgentClient() *agent.AgentClient {
|
||||||
|
webAgentClientOnce.Do(func() {
|
||||||
|
if cfg == nil {
|
||||||
|
panic("cfg not initialized")
|
||||||
|
}
|
||||||
|
if logger == nil {
|
||||||
|
panic("logger not initialized")
|
||||||
|
}
|
||||||
|
getToken := func() string {
|
||||||
|
if chunkParser == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return chunkParser.GetToken()
|
||||||
|
}
|
||||||
|
webAgentClient = agent.NewAgentClient(cfg, *logger, getToken)
|
||||||
|
})
|
||||||
|
return webAgentClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerWebAgents registers WebAgentB instances for websearch and read_url tools.
|
||||||
|
func registerWebAgents() {
|
||||||
|
webAgentsOnce.Do(func() {
|
||||||
|
client := getWebAgentClient()
|
||||||
|
// Register websearch agent
|
||||||
|
agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt))
|
||||||
|
// Register read_url agent
|
||||||
|
agent.Register("read_url", agent.NewWebAgentB(client, readURLSysPrompt))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// web search (depends on extra server)
|
// web search (depends on extra server)
|
||||||
func websearch(args map[string]string) []byte {
|
func websearch(args map[string]string) []byte {
|
||||||
// make http request return bytes
|
// make http request return bytes
|
||||||
@@ -597,7 +635,6 @@ var globalTodoList = TodoList{
|
|||||||
Items: []TodoItem{},
|
Items: []TodoItem{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Todo Management Tools
|
// Todo Management Tools
|
||||||
func todoCreate(args map[string]string) []byte {
|
func todoCreate(args map[string]string) []byte {
|
||||||
task, ok := args["task"]
|
task, ok := args["task"]
|
||||||
@@ -851,6 +888,7 @@ var fnMap = map[string]fnSig{
|
|||||||
|
|
||||||
// callToolWithAgent calls the tool and applies any registered agent.
|
// callToolWithAgent calls the tool and applies any registered agent.
|
||||||
func callToolWithAgent(name string, args map[string]string) []byte {
|
func callToolWithAgent(name string, args map[string]string) []byte {
|
||||||
|
registerWebAgents()
|
||||||
f, ok := fnMap[name]
|
f, ok := fnMap[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
return []byte(fmt.Sprintf("tool %s not found", name))
|
return []byte(fmt.Sprintf("tool %s not found", name))
|
||||||
@@ -862,16 +900,6 @@ func callToolWithAgent(name string, args map[string]string) []byte {
|
|||||||
return raw
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerDefaultAgents registers default agents for formatting.
|
|
||||||
func registerDefaultAgents() {
|
|
||||||
agent.Register("websearch", agent.DefaultFormatter("websearch"))
|
|
||||||
agent.Register("read_url", agent.DefaultFormatter("read_url"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
registerDefaultAgents()
|
|
||||||
}
|
|
||||||
|
|
||||||
// openai style def
|
// openai style def
|
||||||
var baseTools = []models.Tool{
|
var baseTools = []models.Tool{
|
||||||
// websearch
|
// websearch
|
||||||
|
|||||||
Reference in New Issue
Block a user