Feat: add tests
This commit is contained in:
30
main.go
30
main.go
@@ -180,6 +180,22 @@ func buildRPPrompt(sysPrompt string, history []models.RPMessage, characterName s
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func processLLMResponse(respText string) (string, models.ToolCallInfo) {
|
||||
// Check if the response indicates tool usage
|
||||
var toolCall models.ToolCallInfo
|
||||
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
|
||||
// Extract tool name from the marker
|
||||
toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
|
||||
toolCall = models.ToolCallInfo{
|
||||
Name: toolName,
|
||||
// Arguments would need to be parsed from the actual response
|
||||
}
|
||||
// Remove the marker from the response text
|
||||
respText = fmt.Sprintf("Used tool: %s", toolName)
|
||||
}
|
||||
return respText, toolCall
|
||||
}
|
||||
|
||||
func runBench(questions []models.Question) ([]models.Answer, error) {
|
||||
answers := []models.Answer{}
|
||||
for _, q := range questions {
|
||||
@@ -195,18 +211,8 @@ func runBench(questions []models.Question) ([]models.Answer, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if the response indicates tool usage
|
||||
var toolCall models.ToolCallInfo
|
||||
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
|
||||
// Extract tool name from the marker
|
||||
toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
|
||||
toolCall = models.ToolCallInfo{
|
||||
Name: toolName,
|
||||
// Arguments would need to be parsed from the actual response
|
||||
}
|
||||
// Remove the marker from the response text
|
||||
respText = fmt.Sprintf("Used tool: %s", toolName)
|
||||
}
|
||||
// Process the response to detect tool usage
|
||||
respText, toolCall := processLLMResponse(respText)
|
||||
|
||||
a := models.Answer{
|
||||
Q: q,
|
||||
|
130
parser_integration_test.go
Normal file
130
parser_integration_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestParserToolDetection tests the parser implementations with mock responses
|
||||
func TestParserToolDetection(t *testing.T) {
|
||||
// Create a logger for testing
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
// Test cases with different response types
|
||||
testCases := []struct {
|
||||
name string
|
||||
responseJSON string
|
||||
expectedOutput string
|
||||
isToolCall bool
|
||||
expectedTool string
|
||||
}{
|
||||
{
|
||||
name: "RegularTextResponse",
|
||||
responseJSON: `{
|
||||
"id": "test-id",
|
||||
"object": "text.completion",
|
||||
"created": 1234567890,
|
||||
"model": "deepseek-chat",
|
||||
"choices": [
|
||||
{
|
||||
"text": "The capital of France is Paris.",
|
||||
"index": 0,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedOutput: "The capital of France is Paris.",
|
||||
isToolCall: false,
|
||||
},
|
||||
{
|
||||
name: "ToolCallResponse",
|
||||
responseJSON: `{
|
||||
"id": "test-id",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "deepseek-chat",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "tool_calls",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_timestamp",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedOutput: "[TOOL_CALL:get_current_timestamp]",
|
||||
isToolCall: true,
|
||||
expectedTool: "get_current_timestamp",
|
||||
},
|
||||
{
|
||||
name: "RegularMessageResponse",
|
||||
responseJSON: `{
|
||||
"id": "test-id",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "deepseek-chat",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello, how can I help you today?"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedOutput: "Hello, how can I help you today?",
|
||||
isToolCall: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test DeepSeekParser
|
||||
t.Run("DeepSeekParser", func(t *testing.T) {
|
||||
parser := NewDeepSeekParser(logger)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, err := parser.ParseBytes([]byte(tc.responseJSON))
|
||||
if err != nil {
|
||||
t.Errorf("ParseBytes returned an error: %v", err)
|
||||
}
|
||||
|
||||
if result != tc.expectedOutput {
|
||||
t.Errorf("Expected %s, got %s", tc.expectedOutput, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Test OpenRouterParser
|
||||
t.Run("OpenRouterParser", func(t *testing.T) {
|
||||
parser := NewOpenRouterParser(logger)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, err := parser.ParseBytes([]byte(tc.responseJSON))
|
||||
if err != nil {
|
||||
t.Errorf("ParseBytes returned an error: %v", err)
|
||||
}
|
||||
|
||||
if result != tc.expectedOutput {
|
||||
t.Errorf("Expected %s, got %s", tc.expectedOutput, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
192
parser_test.go
Normal file
192
parser_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDeepSeekParser_ParseBytes(t *testing.T) {
|
||||
// Create a logger for testing
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
// Create a new deepSeekParser
|
||||
parser := NewDeepSeekParser(logger)
|
||||
|
||||
// Test case 1: Regular text response
|
||||
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||
// Mock response JSON for a regular text response
|
||||
responseJSON := `{
|
||||
"id": "test-id",
|
||||
"object": "text.completion",
|
||||
"created": 1234567890,
|
||||
"model": "deepseek-chat",
|
||||
"choices": [
|
||||
{
|
||||
"text": "This is a regular response",
|
||||
"index": 0,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||
if err != nil {
|
||||
t.Errorf("ParseBytes returned an error: %v", err)
|
||||
}
|
||||
|
||||
expected := "This is a regular response"
|
||||
if result != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, result)
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 2: Tool call response
|
||||
t.Run("ToolCallResponse", func(t *testing.T) {
|
||||
// Mock response JSON for a tool call response
|
||||
responseJSON := `{
|
||||
"id": "test-id",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "deepseek-chat",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "tool_calls",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_timestamp",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||
if err != nil {
|
||||
t.Errorf("ParseBytes returned an error: %v", err)
|
||||
}
|
||||
|
||||
expected := "[TOOL_CALL:get_current_timestamp]"
|
||||
if result != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, result)
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 3: Regular message response (OpenAI format)
|
||||
t.Run("RegularMessageResponse", func(t *testing.T) {
|
||||
// Mock response JSON for a regular message response
|
||||
responseJSON := `{
|
||||
"id": "test-id",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "deepseek-chat",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "This is a regular message response"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||
if err != nil {
|
||||
t.Errorf("ParseBytes returned an error: %v", err)
|
||||
}
|
||||
|
||||
expected := "This is a regular message response"
|
||||
if result != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenRouterParser_ParseBytes(t *testing.T) {
|
||||
// Create a logger for testing
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
// Create a new openRouterParser
|
||||
parser := NewOpenRouterParser(logger)
|
||||
|
||||
// Test case 1: Regular text response
|
||||
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||
// Mock response JSON for a regular text response
|
||||
responseJSON := `{
|
||||
"id": "test-id",
|
||||
"object": "text.completion",
|
||||
"created": 1234567890,
|
||||
"model": "deepseek-r1",
|
||||
"choices": [
|
||||
{
|
||||
"text": "This is a regular response",
|
||||
"index": 0,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||
if err != nil {
|
||||
t.Errorf("ParseBytes returned an error: %v", err)
|
||||
}
|
||||
|
||||
expected := "This is a regular response"
|
||||
if result != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, result)
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 2: Tool call response
|
||||
t.Run("ToolCallResponse", func(t *testing.T) {
|
||||
// Mock response JSON for a tool call response
|
||||
responseJSON := `{
|
||||
"id": "test-id",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "deepseek-r1",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "tool_calls",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "send_email",
|
||||
"arguments": "{\"address\":\"test@example.com\",\"title\":\"Test\",\"body\":\"This is a test email\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||
if err != nil {
|
||||
t.Errorf("ParseBytes returned an error: %v", err)
|
||||
}
|
||||
|
||||
expected := "[TOOL_CALL:send_email]"
|
||||
if result != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, result)
|
||||
}
|
||||
})
|
||||
}
|
69
tool_detection_test.go
Normal file
69
tool_detection_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProcessLLMResponse tests the processing of LLM responses to detect tool usage
|
||||
func TestProcessLLMResponse(t *testing.T) {
|
||||
// Test data
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedOutput string
|
||||
expectedToolName string
|
||||
expectToolCall bool
|
||||
}{
|
||||
{
|
||||
name: "RegularTextResponse",
|
||||
input: "The capital of France is Paris.",
|
||||
expectedOutput: "The capital of France is Paris.",
|
||||
expectedToolName: "",
|
||||
expectToolCall: false,
|
||||
},
|
||||
{
|
||||
name: "ToolCallResponse",
|
||||
input: "[TOOL_CALL:get_current_timestamp]",
|
||||
expectedOutput: "Used tool: get_current_timestamp",
|
||||
expectedToolName: "get_current_timestamp",
|
||||
expectToolCall: true,
|
||||
},
|
||||
{
|
||||
name: "AnotherToolCallResponse",
|
||||
input: "[TOOL_CALL:send_email]",
|
||||
expectedOutput: "Used tool: send_email",
|
||||
expectedToolName: "send_email",
|
||||
expectToolCall: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// This mimics the logic in processLLMResponse function
|
||||
respText := tc.input
|
||||
var toolName string
|
||||
|
||||
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
|
||||
// Extract tool name from the marker
|
||||
toolName = strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
|
||||
// Remove the marker from the response text
|
||||
respText = "Used tool: " + toolName
|
||||
}
|
||||
|
||||
if respText != tc.expectedOutput {
|
||||
t.Errorf("Expected output '%s', got '%s'", tc.expectedOutput, respText)
|
||||
}
|
||||
|
||||
if tc.expectToolCall {
|
||||
if toolName != tc.expectedToolName {
|
||||
t.Errorf("Expected tool name '%s', got '%s'", tc.expectedToolName, toolName)
|
||||
}
|
||||
} else {
|
||||
if toolName != "" {
|
||||
t.Errorf("Expected no tool call, but got tool name '%s'", toolName)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user