From 8699b1a84e74e5e4969ea4f79b2602a244004829 Mon Sep 17 00:00:00 2001
From: Grail Finder
Date: Fri, 5 Sep 2025 14:03:17 +0300
Subject: [PATCH] Feat: LLM tool use

---
 main.go          | 21 ++++++++++-
 models/models.go | 45 ++++++++++++++++--------
 parser.go        | 95 ++++++++++++++++++++++++++++++++++--------------
 3 files changed, 118 insertions(+), 43 deletions(-)

diff --git a/main.go b/main.go
index 884f88e..ef44924 100644
--- a/main.go
+++ b/main.go
@@ -194,7 +194,26 @@ func runBench(questions []models.Question) ([]models.Answer, error) {
 			logger.Error("failed to parse llm response", "error", err)
 			continue
 		}
-		a := models.Answer{Q: q, Answer: respText, Model: "default"} // model name from resp?
+
+		// Check if the response indicates tool usage
+		var toolCall *models.ToolCallInfo
+		if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
+			// Extract the tool name from the marker
+			toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
+			toolCall = &models.ToolCallInfo{
+				Name: toolName,
+				// Arguments would need to be parsed from the actual response
+			}
+			// Replace the marker with a readable note
+			respText = fmt.Sprintf("Used tool: %s", toolName)
+		}
+
+		a := models.Answer{
+			Q:        q,
+			Answer:   respText,
+			Model:    "default", // model name from resp?
+			ToolCall: toolCall,
+		}
 		answers = append(answers, a)
 	}
 	return answers, nil
diff --git a/models/models.go b/models/models.go
index cef8242..2a1d782 100644
--- a/models/models.go
+++ b/models/models.go
@@ -7,16 +7,15 @@ type Question struct {
 }
 
 type Answer struct {
-	Q        Question
-	Answer   string   `json:"answer"`
-	Model    string   `json:"model"`
-	// resp time?
-	ToolCall ToolCall `json:"tool_call"`
+	Q        Question
+	Answer   string        `json:"answer"`
+	Model    string        `json:"model"`
+	ToolCall *ToolCallInfo `json:"tool_call,omitempty"` // pointer, so omitempty drops it when unset
 }
 
-type ToolCall struct {
-	Name string                  `json:"name"`
-	Args map[string]ToolArgProps `json:"args"`
+type ToolCallInfo struct {
+	Name string            `json:"name"`
+	Args map[string]string `json:"args"`
 }
 
 // type OpenRouterResp struct {
@@ -44,13 +43,31 @@ type ToolCall struct {
 // 	} `json:"usage"`
 // }
 
+type ToolCallFunction struct {
+	Name      string `json:"name"`
+	Arguments string `json:"arguments"`
+}
+
+type ToolCall struct {
+	ID       string           `json:"id"`
+	Type     string           `json:"type"`
+	Function ToolCallFunction `json:"function"`
+}
+
+type DSRespChoice struct {
+	Text         string `json:"text"`
+	Index        int    `json:"index"`
+	FinishReason string `json:"finish_reason"`
+	Message      struct {
+		Role      string     `json:"role"`
+		Content   string     `json:"content"`
+		ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+	} `json:"message,omitempty"`
+}
+
 type DSResp struct {
-	ID      string `json:"id"`
-	Choices []struct {
-		Text         string `json:"text"`
-		Index        int    `json:"index"`
-		FinishReason string `json:"finish_reason"`
-	} `json:"choices"`
+	ID      string         `json:"id"`
+	Choices []DSRespChoice `json:"choices"`
 	Created           int    `json:"created"`
 	Model             string `json:"model"`
 	SystemFingerprint string `json:"system_fingerprint"`
diff --git a/parser.go b/parser.go
index e27806a..d4df06e 100644
--- a/parser.go
+++ b/parser.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"grailbench/models"
 	"io"
 	"log/slog"
@@ -35,26 +36,43 @@ func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
 		err := errors.New("empty choices in dsResp")
 		return "", err
 	}
-	text := dsResp.Choices[0].Text
-	return text, nil
+
+	// Check if the response contains tool calls
+	choice := dsResp.Choices[0]
+
+	// Handle response with message field (OpenAI format)
+	if choice.Message.Role != "" {
+		if len(choice.Message.ToolCalls) > 0 {
+			// Handle tool call response
+			toolCall := choice.Message.ToolCalls[0]
+			// Return a special marker indicating tool usage
+			return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
+		}
+		// Regular text response
+		return choice.Message.Content, nil
+	}
+
+	// Handle response with text field (legacy format)
+	return choice.Text, nil
 }
 
 func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
 	payload := struct {
-		Model            string      `json:"model"`
-		Prompt           string      `json:"prompt"`
-		Echo             bool        `json:"echo"`
-		FrequencyPenalty float64     `json:"frequency_penalty"`
-		Logprobs         int         `json:"logprobs"`
-		MaxTokens        int         `json:"max_tokens"`
-		PresencePenalty  float64     `json:"presence_penalty"`
-		Stop             interface{} `json:"stop"`
-		Stream           bool        `json:"stream"`
-		StreamOptions    interface{} `json:"stream_options"`
-		Suffix           interface{} `json:"suffix"`
-		Temperature      float64     `json:"temperature"`
-		NProbs           int         `json:"n_probs"`
-		TopP             float64     `json:"top_p"`
+		Model            string        `json:"model"`
+		Prompt           string        `json:"prompt"`
+		Echo             bool          `json:"echo"`
+		FrequencyPenalty float64       `json:"frequency_penalty"`
+		Logprobs         int           `json:"logprobs"`
+		MaxTokens        int           `json:"max_tokens"`
+		PresencePenalty  float64       `json:"presence_penalty"`
+		Stop             interface{}   `json:"stop"`
+		Stream           bool          `json:"stream"`
+		StreamOptions    interface{}   `json:"stream_options"`
+		Suffix           interface{}   `json:"suffix"`
+		Temperature      float64       `json:"temperature"`
+		NProbs           int           `json:"n_probs"`
+		TopP             float64       `json:"top_p"`
+		Tools            []models.Tool `json:"tools,omitempty"`
 	}{
 		Model:  "deepseek-chat",
 		Prompt: prompt,
@@ -70,6 +88,7 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
 		Temperature: 1,
 		NProbs:      10,
 		TopP:        1,
+		Tools:       baseTools, // Include the tools in the request
 	}
 	b, err := json.Marshal(payload)
 	if err != nil {
@@ -100,14 +119,15 @@ func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
 
 func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
 	payload := struct {
-		Model            string   `json:"model"`
-		Prompt           string   `json:"prompt"`
-		FrequencyPenalty float64  `json:"frequency_penalty"`
-		MaxTokens        int      `json:"max_tokens"`
-		Stop             []string `json:"stop"`
-		Stream           bool     `json:"stream"`
-		Temperature      float64  `json:"temperature"`
-		TopP             float64  `json:"top_p"`
+		Model            string        `json:"model"`
+		Prompt           string        `json:"prompt"`
+		FrequencyPenalty float64       `json:"frequency_penalty"`
+		MaxTokens        int           `json:"max_tokens"`
+		Stop             []string      `json:"stop"`
+		Stream           bool          `json:"stream"`
+		Temperature      float64       `json:"temperature"`
+		TopP             float64       `json:"top_p"`
+		Tools            []models.Tool `json:"tools,omitempty"`
 	}{
 		Model:  "local-model",
 		Prompt: prompt,
@@ -117,6 +137,7 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
 		Stream:      false,
 		Temperature: 0.4,
 		TopP:        1,
+		Tools:       baseTools, // Include tools (though local model may not support them)
 	}
 
 	b, err := json.Marshal(payload)
@@ -152,8 +173,24 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
 		err := errors.New("empty choices in resp")
 		return "", err
 	}
-	text := resp.Choices[0].Text
-	return text, nil
+
+	// Check if the response contains tool calls
+	choice := resp.Choices[0]
+
+	// Handle response with message field (OpenAI format)
+	if choice.Message.Role != "" {
+		if len(choice.Message.ToolCalls) > 0 {
+			// Handle tool call response
+			toolCall := choice.Message.ToolCalls[0]
+			// Return a special marker indicating tool usage
+			return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
+		}
+		// Regular text response
+		return choice.Message.Content, nil
+	}
+
+	// Handle response with text field (legacy format)
+	return choice.Text, nil
 }
 
 func (p *openRouterParser) MakePayload(prompt string) io.Reader {
@@ -163,11 +200,13 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
 	// Get next model index using atomic addition for thread safety
 	p.modelIndex++
 	payload := struct {
-		Model  string `json:"model"`
-		Prompt string `json:"prompt"`
+		Model  string        `json:"model"`
+		Prompt string        `json:"prompt"`
+		Tools  []models.Tool `json:"tools,omitempty"`
 	}{
 		Model:  model,
 		Prompt: prompt,
+		Tools:  baseTools, // Include the tools in the request
 	}
 
 	b, err := json.Marshal(payload)
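
Note: the hunks above reference models.Tool and a package-level baseTools
value whose definitions are not part of this diff. For review purposes,
a minimal sketch of what they might look like, assuming an OpenAI-style
function-calling schema (the field layout and the get_weather example
tool are assumptions, not taken from the patch):

	// models/models.go (sketch): OpenAI-style tool schema.
	type ToolFunction struct {
		Name        string         `json:"name"`
		Description string         `json:"description"`
		Parameters  map[string]any `json:"parameters"` // JSON Schema for the tool's arguments
	}

	type Tool struct {
		Type     string       `json:"type"` // "function" for function-calling tools
		Function ToolFunction `json:"function"`
	}

	// parser.go (sketch): the tool list attached to every request.
	var baseTools = []models.Tool{
		{
			Type: "function",
			Function: models.ToolFunction{
				Name:        "get_weather", // hypothetical example tool
				Description: "Get the current weather for a location",
				Parameters: map[string]any{
					"type": "object",
					"properties": map[string]any{
						"location": map[string]any{"type": "string"},
					},
					"required": []string{"location"},
				},
			},
		},
	}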
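
The marker protocol round-trips as follows: ParseBytes collapses an
OpenAI-style tool_calls choice into "[TOOL_CALL:<name>]", and runBench
recognizes that marker, records the name in models.ToolCallInfo, and
rewrites the answer text. A sketch of the parser half, assuming
deepSeekParser's zero value is usable and reusing the hypothetical
get_weather tool from the note above:

	// parser_test.go (sketch): drives the new tool-call branch.
	package main

	import "testing"

	func TestParseBytesToolCall(t *testing.T) {
		body := []byte(`{
			"id": "resp-1",
			"model": "deepseek-chat",
			"choices": [{
				"index": 0,
				"finish_reason": "tool_calls",
				"message": {
					"role": "assistant",
					"content": "",
					"tool_calls": [{
						"id": "call-1",
						"type": "function",
						"function": {"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}
					}]
				}
			}]
		}`)
		p := &deepSeekParser{}
		got, err := p.ParseBytes(body)
		if err != nil {
			t.Fatal(err)
		}
		// The tool-call branch returns the marker rather than message content.
		if got != "[TOOL_CALL:get_weather]" {
			t.Fatalf("want tool-call marker, got %q", got)
		}
	}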