Feat: llm tool use
This commit is contained in:
23
main.go
23
main.go
@@ -4,8 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"flag"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -194,7 +194,26 @@ func runBench(questions []models.Question) ([]models.Answer, error) {
|
|||||||
logger.Error("failed to parse llm response", "error", err)
|
logger.Error("failed to parse llm response", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
a := models.Answer{Q: q, Answer: respText, Model: "default"} // model name from resp?
|
|
||||||
|
// Check if the response indicates tool usage
|
||||||
|
var toolCall models.ToolCallInfo
|
||||||
|
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
|
||||||
|
// Extract tool name from the marker
|
||||||
|
toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
|
||||||
|
toolCall = models.ToolCallInfo{
|
||||||
|
Name: toolName,
|
||||||
|
// Arguments would need to be parsed from the actual response
|
||||||
|
}
|
||||||
|
// Remove the marker from the response text
|
||||||
|
respText = fmt.Sprintf("Used tool: %s", toolName)
|
||||||
|
}
|
||||||
|
|
||||||
|
a := models.Answer{
|
||||||
|
Q: q,
|
||||||
|
Answer: respText,
|
||||||
|
Model: "default", // model name from resp?
|
||||||
|
ToolCall: toolCall,
|
||||||
|
}
|
||||||
answers = append(answers, a)
|
answers = append(answers, a)
|
||||||
}
|
}
|
||||||
return answers, nil
|
return answers, nil
|
||||||
|
@@ -10,13 +10,12 @@ type Answer struct {
|
|||||||
Q Question
|
Q Question
|
||||||
Answer string `json:"answer"`
|
Answer string `json:"answer"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
// resp time?
|
ToolCall ToolCallInfo `json:"tool_call,omitempty"`
|
||||||
ToolCall ToolCall `json:"tool_call"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCallInfo struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Args map[string]ToolArgProps `json:"args"`
|
Args map[string]string `json:"args"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// type OpenRouterResp struct {
|
// type OpenRouterResp struct {
|
||||||
@@ -44,13 +43,31 @@ type ToolCall struct {
|
|||||||
// } `json:"usage"`
|
// } `json:"usage"`
|
||||||
// }
|
// }
|
||||||
|
|
||||||
type DSResp struct {
|
type ToolCallFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Choices []struct {
|
Type string `json:"type"`
|
||||||
|
Function ToolCallFunction `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DSRespChoice struct {
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
FinishReason string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
} `json:"choices"`
|
Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
} `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DSResp struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Choices []DSRespChoice `json:"choices"`
|
||||||
Created int `json:"created"`
|
Created int `json:"created"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
SystemFingerprint string `json:"system_fingerprint"`
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
47
parser.go
47
parser.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"grailbench/models"
|
"grailbench/models"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -35,8 +36,24 @@ func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
|||||||
err := errors.New("empty choices in dsResp")
|
err := errors.New("empty choices in dsResp")
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
text := dsResp.Choices[0].Text
|
|
||||||
return text, nil
|
// Check if the response contains tool calls
|
||||||
|
choice := dsResp.Choices[0]
|
||||||
|
|
||||||
|
// Handle response with message field (OpenAI format)
|
||||||
|
if choice.Message.Role != "" {
|
||||||
|
if len(choice.Message.ToolCalls) > 0 {
|
||||||
|
// Handle tool call response
|
||||||
|
toolCall := choice.Message.ToolCalls[0]
|
||||||
|
// Return a special marker indicating tool usage
|
||||||
|
return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
|
||||||
|
}
|
||||||
|
// Regular text response
|
||||||
|
return choice.Message.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle response with text field (legacy format)
|
||||||
|
return choice.Text, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
||||||
@@ -55,6 +72,7 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
|||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature"`
|
||||||
NProbs int `json:"n_probs"`
|
NProbs int `json:"n_probs"`
|
||||||
TopP float64 `json:"top_p"`
|
TopP float64 `json:"top_p"`
|
||||||
|
Tools []models.Tool `json:"tools,omitempty"`
|
||||||
}{
|
}{
|
||||||
Model: "deepseek-chat",
|
Model: "deepseek-chat",
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
@@ -70,6 +88,7 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
|||||||
Temperature: 1,
|
Temperature: 1,
|
||||||
NProbs: 10,
|
NProbs: 10,
|
||||||
TopP: 1,
|
TopP: 1,
|
||||||
|
Tools: baseTools, // Include the tools in the request
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(payload)
|
b, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -108,6 +127,7 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
|
|||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature"`
|
||||||
TopP float64 `json:"top_p"`
|
TopP float64 `json:"top_p"`
|
||||||
|
Tools []models.Tool `json:"tools,omitempty"`
|
||||||
}{
|
}{
|
||||||
Model: "local-model",
|
Model: "local-model",
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
@@ -117,6 +137,7 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
|
|||||||
Stream: false,
|
Stream: false,
|
||||||
Temperature: 0.4,
|
Temperature: 0.4,
|
||||||
TopP: 1,
|
TopP: 1,
|
||||||
|
Tools: baseTools, // Include tools (though local model may not support them)
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := json.Marshal(payload)
|
b, err := json.Marshal(payload)
|
||||||
@@ -152,8 +173,24 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
|||||||
err := errors.New("empty choices in resp")
|
err := errors.New("empty choices in resp")
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
text := resp.Choices[0].Text
|
|
||||||
return text, nil
|
// Check if the response contains tool calls
|
||||||
|
choice := resp.Choices[0]
|
||||||
|
|
||||||
|
// Handle response with message field (OpenAI format)
|
||||||
|
if choice.Message.Role != "" {
|
||||||
|
if len(choice.Message.ToolCalls) > 0 {
|
||||||
|
// Handle tool call response
|
||||||
|
toolCall := choice.Message.ToolCalls[0]
|
||||||
|
// Return a special marker indicating tool usage
|
||||||
|
return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
|
||||||
|
}
|
||||||
|
// Regular text response
|
||||||
|
return choice.Message.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle response with text field (legacy format)
|
||||||
|
return choice.Text, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
||||||
@@ -165,9 +202,11 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
|||||||
payload := struct {
|
payload := struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
Tools []models.Tool `json:"tools,omitempty"`
|
||||||
}{
|
}{
|
||||||
Model: model,
|
Model: model,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
|
Tools: baseTools, // Include the tools in the request
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := json.Marshal(payload)
|
b, err := json.Marshal(payload)
|
||||||
|
Reference in New Issue
Block a user