feat: add LLM tool-use support (tool definitions in request payloads, tool-call parsing in responses)

This commit is contained in:
Grail Finder
2025-09-05 14:03:17 +03:00
parent 8b88d2d824
commit 8699b1a84e
3 changed files with 119 additions and 44 deletions

23
main.go
View File

@@ -4,8 +4,8 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"flag"
"fmt" "fmt"
"flag"
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -194,7 +194,26 @@ func runBench(questions []models.Question) ([]models.Answer, error) {
logger.Error("failed to parse llm response", "error", err) logger.Error("failed to parse llm response", "error", err)
continue continue
} }
a := models.Answer{Q: q, Answer: respText, Model: "default"} // model name from resp?
// Check if the response indicates tool usage
var toolCall models.ToolCallInfo
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
// Extract tool name from the marker
toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
toolCall = models.ToolCallInfo{
Name: toolName,
// Arguments would need to be parsed from the actual response
}
// Remove the marker from the response text
respText = fmt.Sprintf("Used tool: %s", toolName)
}
a := models.Answer{
Q: q,
Answer: respText,
Model: "default", // model name from resp?
ToolCall: toolCall,
}
answers = append(answers, a) answers = append(answers, a)
} }
return answers, nil return answers, nil

View File

@@ -10,13 +10,12 @@ type Answer struct {
Q Question Q Question
Answer string `json:"answer"` Answer string `json:"answer"`
Model string `json:"model"` Model string `json:"model"`
// resp time? ToolCall ToolCallInfo `json:"tool_call,omitempty"`
ToolCall ToolCall `json:"tool_call"`
} }
type ToolCall struct { type ToolCallInfo struct {
Name string `json:"name"` Name string `json:"name"`
Args map[string]ToolArgProps `json:"args"` Args map[string]string `json:"args"`
} }
// type OpenRouterResp struct { // type OpenRouterResp struct {
@@ -44,13 +43,31 @@ type ToolCall struct {
// } `json:"usage"` // } `json:"usage"`
// } // }
type DSResp struct { type ToolCallFunction struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type ToolCall struct {
ID string `json:"id"` ID string `json:"id"`
Choices []struct { Type string `json:"type"`
Function ToolCallFunction `json:"function"`
}
type DSRespChoice struct {
Text string `json:"text"` Text string `json:"text"`
Index int `json:"index"` Index int `json:"index"`
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} `json:"choices"` Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
} `json:"message,omitempty"`
}
type DSResp struct {
ID string `json:"id"`
Choices []DSRespChoice `json:"choices"`
Created int `json:"created"` Created int `json:"created"`
Model string `json:"model"` Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"` SystemFingerprint string `json:"system_fingerprint"`

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"grailbench/models" "grailbench/models"
"io" "io"
"log/slog" "log/slog"
@@ -35,8 +36,24 @@ func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
err := errors.New("empty choices in dsResp") err := errors.New("empty choices in dsResp")
return "", err return "", err
} }
text := dsResp.Choices[0].Text
return text, nil // Check if the response contains tool calls
choice := dsResp.Choices[0]
// Handle response with message field (OpenAI format)
if choice.Message.Role != "" {
if len(choice.Message.ToolCalls) > 0 {
// Handle tool call response
toolCall := choice.Message.ToolCalls[0]
// Return a special marker indicating tool usage
return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
}
// Regular text response
return choice.Message.Content, nil
}
// Handle response with text field (legacy format)
return choice.Text, nil
} }
func (p *deepSeekParser) MakePayload(prompt string) io.Reader { func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
@@ -55,6 +72,7 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
Temperature float64 `json:"temperature"` Temperature float64 `json:"temperature"`
NProbs int `json:"n_probs"` NProbs int `json:"n_probs"`
TopP float64 `json:"top_p"` TopP float64 `json:"top_p"`
Tools []models.Tool `json:"tools,omitempty"`
}{ }{
Model: "deepseek-chat", Model: "deepseek-chat",
Prompt: prompt, Prompt: prompt,
@@ -70,6 +88,7 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
Temperature: 1, Temperature: 1,
NProbs: 10, NProbs: 10,
TopP: 1, TopP: 1,
Tools: baseTools, // Include the tools in the request
} }
b, err := json.Marshal(payload) b, err := json.Marshal(payload)
if err != nil { if err != nil {
@@ -108,6 +127,7 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
Stream bool `json:"stream"` Stream bool `json:"stream"`
Temperature float64 `json:"temperature"` Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"` TopP float64 `json:"top_p"`
Tools []models.Tool `json:"tools,omitempty"`
}{ }{
Model: "local-model", Model: "local-model",
Prompt: prompt, Prompt: prompt,
@@ -117,6 +137,7 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
Stream: false, Stream: false,
Temperature: 0.4, Temperature: 0.4,
TopP: 1, TopP: 1,
Tools: baseTools, // Include tools (though local model may not support them)
} }
b, err := json.Marshal(payload) b, err := json.Marshal(payload)
@@ -152,8 +173,24 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
err := errors.New("empty choices in resp") err := errors.New("empty choices in resp")
return "", err return "", err
} }
text := resp.Choices[0].Text
return text, nil // Check if the response contains tool calls
choice := resp.Choices[0]
// Handle response with message field (OpenAI format)
if choice.Message.Role != "" {
if len(choice.Message.ToolCalls) > 0 {
// Handle tool call response
toolCall := choice.Message.ToolCalls[0]
// Return a special marker indicating tool usage
return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
}
// Regular text response
return choice.Message.Content, nil
}
// Handle response with text field (legacy format)
return choice.Text, nil
} }
func (p *openRouterParser) MakePayload(prompt string) io.Reader { func (p *openRouterParser) MakePayload(prompt string) io.Reader {
@@ -165,9 +202,11 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
payload := struct { payload := struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Tools []models.Tool `json:"tools,omitempty"`
}{ }{
Model: model, Model: model,
Prompt: prompt, Prompt: prompt,
Tools: baseTools, // Include the tools in the request
} }
b, err := json.Marshal(payload) b, err := json.Marshal(payload)