Feat: llm tool use

This commit is contained in:
Grail Finder
2025-09-05 14:03:17 +03:00
parent 8b88d2d824
commit 8699b1a84e
3 changed files with 119 additions and 44 deletions

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"grailbench/models"
"io"
"log/slog"
@@ -35,26 +36,43 @@ func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
err := errors.New("empty choices in dsResp")
return "", err
}
text := dsResp.Choices[0].Text
return text, nil
// Check if the response contains tool calls
choice := dsResp.Choices[0]
// Handle response with message field (OpenAI format)
if choice.Message.Role != "" {
if len(choice.Message.ToolCalls) > 0 {
// Handle tool call response
toolCall := choice.Message.ToolCalls[0]
// Return a special marker indicating tool usage
return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
}
// Regular text response
return choice.Message.Content, nil
}
// Handle response with text field (legacy format)
return choice.Text, nil
}
func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
payload := struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Echo bool `json:"echo"`
FrequencyPenalty float64 `json:"frequency_penalty"`
Logprobs int `json:"logprobs"`
MaxTokens int `json:"max_tokens"`
PresencePenalty float64 `json:"presence_penalty"`
Stop interface{} `json:"stop"`
Stream bool `json:"stream"`
StreamOptions interface{} `json:"stream_options"`
Suffix interface{} `json:"suffix"`
Temperature float64 `json:"temperature"`
NProbs int `json:"n_probs"`
TopP float64 `json:"top_p"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Echo bool `json:"echo"`
FrequencyPenalty float64 `json:"frequency_penalty"`
Logprobs int `json:"logprobs"`
MaxTokens int `json:"max_tokens"`
PresencePenalty float64 `json:"presence_penalty"`
Stop interface{} `json:"stop"`
Stream bool `json:"stream"`
StreamOptions interface{} `json:"stream_options"`
Suffix interface{} `json:"suffix"`
Temperature float64 `json:"temperature"`
NProbs int `json:"n_probs"`
TopP float64 `json:"top_p"`
Tools []models.Tool `json:"tools,omitempty"`
}{
Model: "deepseek-chat",
Prompt: prompt,
@@ -70,6 +88,7 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
Temperature: 1,
NProbs: 10,
TopP: 1,
Tools: baseTools, // Include the tools in the request
}
b, err := json.Marshal(payload)
if err != nil {
@@ -100,14 +119,15 @@ func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
payload := struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
FrequencyPenalty float64 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
Stop []string `json:"stop"`
Stream bool `json:"stream"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
Model string `json:"model"`
Prompt string `json:"prompt"`
FrequencyPenalty float64 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
Stop []string `json:"stop"`
Stream bool `json:"stream"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
Tools []models.Tool `json:"tools,omitempty"`
}{
Model: "local-model",
Prompt: prompt,
@@ -117,6 +137,7 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
Stream: false,
Temperature: 0.4,
TopP: 1,
Tools: baseTools, // Include tools (though local model may not support them)
}
b, err := json.Marshal(payload)
@@ -152,8 +173,24 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
err := errors.New("empty choices in resp")
return "", err
}
text := resp.Choices[0].Text
return text, nil
// Check if the response contains tool calls
choice := resp.Choices[0]
// Handle response with message field (OpenAI format)
if choice.Message.Role != "" {
if len(choice.Message.ToolCalls) > 0 {
// Handle tool call response
toolCall := choice.Message.ToolCalls[0]
// Return a special marker indicating tool usage
return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
}
// Regular text response
return choice.Message.Content, nil
}
// Handle response with text field (legacy format)
return choice.Text, nil
}
func (p *openRouterParser) MakePayload(prompt string) io.Reader {
@@ -163,11 +200,13 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
// Get next model index using atomic addition for thread safety
p.modelIndex++
payload := struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Tools []models.Tool `json:"tools,omitempty"`
}{
Model: model,
Prompt: prompt,
Tools: baseTools, // Include the tools in the request
}
b, err := json.Marshal(payload)