Feat: llm tool use
parser.go (95 changes)
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"grailbench/models"
 	"io"
 	"log/slog"
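The one import added here is grailbench/models, which supplies the Tool type used by the payload structs below. Its definition is not shown in this diff; a minimal sketch, assuming it mirrors the OpenAI-style tool schema that the tools JSON field implies, might be:

// Hypothetical sketch of the Tool type from grailbench/models; not part
// of this diff, and the field names are assumptions based on the OpenAI
// tool schema.
package models

// Tool advertises one callable function to the model.
type Tool struct {
	Type     string       `json:"type"` // typically "function"
	Function ToolFunction `json:"function"`
}

// ToolFunction names the function and describes its JSON-Schema parameters.
type ToolFunction struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  map[string]any `json:"parameters"`
}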
@@ -35,26 +36,43 @@ func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
 		err := errors.New("empty choices in dsResp")
 		return "", err
 	}
-	text := dsResp.Choices[0].Text
-	return text, nil
+
+	// Check if the response contains tool calls
+	choice := dsResp.Choices[0]
+
+	// Handle response with message field (OpenAI format)
+	if choice.Message.Role != "" {
+		if len(choice.Message.ToolCalls) > 0 {
+			// Handle tool call response
+			toolCall := choice.Message.ToolCalls[0]
+			// Return a special marker indicating tool usage
+			return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
+		}
+		// Regular text response
+		return choice.Message.Content, nil
+	}
+
+	// Handle response with text field (legacy format)
+	return choice.Text, nil
 }

 func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
 	payload := struct {
-		Model            string      `json:"model"`
-		Prompt           string      `json:"prompt"`
-		Echo             bool        `json:"echo"`
-		FrequencyPenalty float64     `json:"frequency_penalty"`
-		Logprobs         int         `json:"logprobs"`
-		MaxTokens        int         `json:"max_tokens"`
-		PresencePenalty  float64     `json:"presence_penalty"`
-		Stop             interface{} `json:"stop"`
-		Stream           bool        `json:"stream"`
-		StreamOptions    interface{} `json:"stream_options"`
-		Suffix           interface{} `json:"suffix"`
-		Temperature      float64     `json:"temperature"`
-		NProbs           int         `json:"n_probs"`
-		TopP             float64     `json:"top_p"`
+		Model            string        `json:"model"`
+		Prompt           string        `json:"prompt"`
+		Echo             bool          `json:"echo"`
+		FrequencyPenalty float64       `json:"frequency_penalty"`
+		Logprobs         int           `json:"logprobs"`
+		MaxTokens        int           `json:"max_tokens"`
+		PresencePenalty  float64       `json:"presence_penalty"`
+		Stop             interface{}   `json:"stop"`
+		Stream           bool          `json:"stream"`
+		StreamOptions    interface{}   `json:"stream_options"`
+		Suffix           interface{}   `json:"suffix"`
+		Temperature      float64       `json:"temperature"`
+		NProbs           int           `json:"n_probs"`
+		TopP             float64       `json:"top_p"`
+		Tools            []models.Tool `json:"tools,omitempty"`
 	}{
 		Model:  "deepseek-chat",
 		Prompt: prompt,
@@ -70,6 +88,7 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
 		Temperature: 1,
 		NProbs:      10,
 		TopP:        1,
+		Tools:       baseTools, // Include the tools in the request
 	}
 	b, err := json.Marshal(payload)
 	if err != nil {
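ParseBytes now branches on choice.Message before falling back to the legacy choice.Text. The response structs themselves are declared outside this hunk; judging only from the fields the code reads, and assuming they track the OpenAI chat-completions wire format, they would be shaped roughly like:

// Hypothetical response shapes inferred from the fields ParseBytes reads;
// the real declarations live elsewhere in parser.go and may differ.
type choice struct {
	Text    string  `json:"text"`    // legacy completions format
	Message message `json:"message"` // OpenAI chat format
}

type message struct {
	Role      string     `json:"role"`
	Content   string     `json:"content"`
	ToolCalls []toolCall `json:"tool_calls"`
}

type toolCall struct {
	Function struct {
		Name string `json:"name"`
		// OpenAI also sends an arguments JSON string; the marker below drops it.
		Arguments string `json:"arguments"`
	} `json:"function"`
}

Note that the [TOOL_CALL:%s] marker keeps only the function name: the call's arguments, and any second entry in the ToolCalls slice, are discarded, so the caller learns which tool was requested but not with what inputs.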
@@ -100,14 +119,15 @@ func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {

 func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
 	payload := struct {
-		Model            string   `json:"model"`
-		Prompt           string   `json:"prompt"`
-		FrequencyPenalty float64  `json:"frequency_penalty"`
-		MaxTokens        int      `json:"max_tokens"`
-		Stop             []string `json:"stop"`
-		Stream           bool     `json:"stream"`
-		Temperature      float64  `json:"temperature"`
-		TopP             float64  `json:"top_p"`
+		Model            string        `json:"model"`
+		Prompt           string        `json:"prompt"`
+		FrequencyPenalty float64       `json:"frequency_penalty"`
+		MaxTokens        int           `json:"max_tokens"`
+		Stop             []string      `json:"stop"`
+		Stream           bool          `json:"stream"`
+		Temperature      float64       `json:"temperature"`
+		TopP             float64       `json:"top_p"`
+		Tools            []models.Tool `json:"tools,omitempty"`
 	}{
 		Model:  "local-model",
 		Prompt: prompt,
@@ -117,6 +137,7 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
 		Stream:      false,
 		Temperature: 0.4,
 		TopP:        1,
+		Tools:       baseTools, // Include tools (though local model may not support them)
 	}

 	b, err := json.Marshal(payload)
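Both payload structs tag the new field with `json:"tools,omitempty"` but then always set it to baseTools, so the key is serialized even for the local model that, per the comment, may not support it. omitempty only drops the key when the slice is nil or empty, as this standalone snippet shows:

// Demonstrates omitempty on a slice field; toolStub stands in for
// models.Tool purely for illustration.
package main

import (
	"encoding/json"
	"fmt"
)

type toolStub struct {
	Type string `json:"type"`
}

type payload struct {
	Model string     `json:"model"`
	Tools []toolStub `json:"tools,omitempty"`
}

func main() {
	empty, _ := json.Marshal(payload{Model: "local-model"})
	fmt.Println(string(empty)) // {"model":"local-model"} ... key omitted

	withTools, _ := json.Marshal(payload{
		Model: "local-model",
		Tools: []toolStub{{Type: "function"}},
	})
	fmt.Println(string(withTools)) // {"model":"local-model","tools":[{"type":"function"}]}
}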
@@ -152,8 +173,24 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
 		err := errors.New("empty choices in resp")
 		return "", err
 	}
-	text := resp.Choices[0].Text
-	return text, nil
+
+	// Check if the response contains tool calls
+	choice := resp.Choices[0]
+
+	// Handle response with message field (OpenAI format)
+	if choice.Message.Role != "" {
+		if len(choice.Message.ToolCalls) > 0 {
+			// Handle tool call response
+			toolCall := choice.Message.ToolCalls[0]
+			// Return a special marker indicating tool usage
+			return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
+		}
+		// Regular text response
+		return choice.Message.Content, nil
+	}
+
+	// Handle response with text field (legacy format)
+	return choice.Text, nil
 }

 func (p *openRouterParser) MakePayload(prompt string) io.Reader {
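All three ParseBytes implementations now surface a tool call as a [TOOL_CALL:name] string rather than structured data. How callers consume that marker is outside this diff; a hypothetical dispatcher, using only the marker format visible above, might look like:

// Hypothetical caller-side handling of the [TOOL_CALL:name] marker that
// ParseBytes returns; no such dispatcher appears in this diff.
package main

import (
	"fmt"
	"strings"
)

func dispatch(parsed string) {
	if rest, ok := strings.CutPrefix(parsed, "[TOOL_CALL:"); ok {
		name := strings.TrimSuffix(rest, "]")
		fmt.Printf("model requested tool %q\n", name)
		// ... look up and invoke the named tool here ...
		return
	}
	fmt.Println("plain completion:", parsed)
}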
@@ -163,11 +200,13 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
 	// Get next model index using atomic addition for thread safety
 	p.modelIndex++
 	payload := struct {
-		Model  string `json:"model"`
-		Prompt string `json:"prompt"`
+		Model  string        `json:"model"`
+		Prompt string        `json:"prompt"`
+		Tools  []models.Tool `json:"tools,omitempty"`
 	}{
 		Model:  model,
 		Prompt: prompt,
+		Tools:  baseTools, // Include the tools in the request
 	}

 	b, err := json.Marshal(payload)
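One caveat in the context lines above: the comment says the model index advances via atomic addition, but a bare p.modelIndex++ is not atomic. If concurrent MakePayload calls are possible, a sketch with the standard sync/atomic package (assuming modelIndex were redeclared as atomic.Uint64, which this diff does not do) would be:

// Sketch only: assumes modelIndex were an atomic.Uint64 and a models slice
// existed for rotation; neither change is part of this diff.
package main

import "sync/atomic"

type openRouterSketch struct {
	modelIndex atomic.Uint64
	models     []string
}

// nextModel rotates through models safely under concurrent callers.
func (p *openRouterSketch) nextModel() string {
	i := p.modelIndex.Add(1) - 1 // Add returns the incremented value
	return p.models[i%uint64(len(p.models))]
}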