Feat: add tests
This commit is contained in:
30
main.go
30
main.go
@@ -180,6 +180,22 @@ func buildRPPrompt(sysPrompt string, history []models.RPMessage, characterName s
|
|||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func processLLMResponse(respText string) (string, models.ToolCallInfo) {
|
||||||
|
// Check if the response indicates tool usage
|
||||||
|
var toolCall models.ToolCallInfo
|
||||||
|
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
|
||||||
|
// Extract tool name from the marker
|
||||||
|
toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
|
||||||
|
toolCall = models.ToolCallInfo{
|
||||||
|
Name: toolName,
|
||||||
|
// Arguments would need to be parsed from the actual response
|
||||||
|
}
|
||||||
|
// Remove the marker from the response text
|
||||||
|
respText = fmt.Sprintf("Used tool: %s", toolName)
|
||||||
|
}
|
||||||
|
return respText, toolCall
|
||||||
|
}
|
||||||
|
|
||||||
func runBench(questions []models.Question) ([]models.Answer, error) {
|
func runBench(questions []models.Question) ([]models.Answer, error) {
|
||||||
answers := []models.Answer{}
|
answers := []models.Answer{}
|
||||||
for _, q := range questions {
|
for _, q := range questions {
|
||||||
@@ -195,18 +211,8 @@ func runBench(questions []models.Question) ([]models.Answer, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the response indicates tool usage
|
// Process the response to detect tool usage
|
||||||
var toolCall models.ToolCallInfo
|
respText, toolCall := processLLMResponse(respText)
|
||||||
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
|
|
||||||
// Extract tool name from the marker
|
|
||||||
toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
|
|
||||||
toolCall = models.ToolCallInfo{
|
|
||||||
Name: toolName,
|
|
||||||
// Arguments would need to be parsed from the actual response
|
|
||||||
}
|
|
||||||
// Remove the marker from the response text
|
|
||||||
respText = fmt.Sprintf("Used tool: %s", toolName)
|
|
||||||
}
|
|
||||||
|
|
||||||
a := models.Answer{
|
a := models.Answer{
|
||||||
Q: q,
|
Q: q,
|
||||||
|
130
parser_integration_test.go
Normal file
130
parser_integration_test.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParserToolDetection tests the parser implementations with mock responses
|
||||||
|
func TestParserToolDetection(t *testing.T) {
|
||||||
|
// Create a logger for testing
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
|
// Test cases with different response types
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
responseJSON string
|
||||||
|
expectedOutput string
|
||||||
|
isToolCall bool
|
||||||
|
expectedTool string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RegularTextResponse",
|
||||||
|
responseJSON: `{
|
||||||
|
"id": "test-id",
|
||||||
|
"object": "text.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "deepseek-chat",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "The capital of France is Paris.",
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
expectedOutput: "The capital of France is Paris.",
|
||||||
|
isToolCall: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ToolCallResponse",
|
||||||
|
responseJSON: `{
|
||||||
|
"id": "test-id",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "deepseek-chat",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": null,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_123",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_timestamp",
|
||||||
|
"arguments": "{}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
expectedOutput: "[TOOL_CALL:get_current_timestamp]",
|
||||||
|
isToolCall: true,
|
||||||
|
expectedTool: "get_current_timestamp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RegularMessageResponse",
|
||||||
|
responseJSON: `{
|
||||||
|
"id": "test-id",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "deepseek-chat",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello, how can I help you today?"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
expectedOutput: "Hello, how can I help you today?",
|
||||||
|
isToolCall: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DeepSeekParser
|
||||||
|
t.Run("DeepSeekParser", func(t *testing.T) {
|
||||||
|
parser := NewDeepSeekParser(logger)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result, err := parser.ParseBytes([]byte(tc.responseJSON))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ParseBytes returned an error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != tc.expectedOutput {
|
||||||
|
t.Errorf("Expected %s, got %s", tc.expectedOutput, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test OpenRouterParser
|
||||||
|
t.Run("OpenRouterParser", func(t *testing.T) {
|
||||||
|
parser := NewOpenRouterParser(logger)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result, err := parser.ParseBytes([]byte(tc.responseJSON))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ParseBytes returned an error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != tc.expectedOutput {
|
||||||
|
t.Errorf("Expected %s, got %s", tc.expectedOutput, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
192
parser_test.go
Normal file
192
parser_test.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeepSeekParser_ParseBytes(t *testing.T) {
|
||||||
|
// Create a logger for testing
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
|
// Create a new deepSeekParser
|
||||||
|
parser := NewDeepSeekParser(logger)
|
||||||
|
|
||||||
|
// Test case 1: Regular text response
|
||||||
|
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||||
|
// Mock response JSON for a regular text response
|
||||||
|
responseJSON := `{
|
||||||
|
"id": "test-id",
|
||||||
|
"object": "text.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "deepseek-chat",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "This is a regular response",
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ParseBytes returned an error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "This is a regular response"
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 2: Tool call response
|
||||||
|
t.Run("ToolCallResponse", func(t *testing.T) {
|
||||||
|
// Mock response JSON for a tool call response
|
||||||
|
responseJSON := `{
|
||||||
|
"id": "test-id",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "deepseek-chat",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": null,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_123",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_timestamp",
|
||||||
|
"arguments": "{}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ParseBytes returned an error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "[TOOL_CALL:get_current_timestamp]"
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 3: Regular message response (OpenAI format)
|
||||||
|
t.Run("RegularMessageResponse", func(t *testing.T) {
|
||||||
|
// Mock response JSON for a regular message response
|
||||||
|
responseJSON := `{
|
||||||
|
"id": "test-id",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "deepseek-chat",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "This is a regular message response"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ParseBytes returned an error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "This is a regular message response"
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenRouterParser_ParseBytes(t *testing.T) {
|
||||||
|
// Create a logger for testing
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
|
// Create a new openRouterParser
|
||||||
|
parser := NewOpenRouterParser(logger)
|
||||||
|
|
||||||
|
// Test case 1: Regular text response
|
||||||
|
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||||
|
// Mock response JSON for a regular text response
|
||||||
|
responseJSON := `{
|
||||||
|
"id": "test-id",
|
||||||
|
"object": "text.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "deepseek-r1",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "This is a regular response",
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ParseBytes returned an error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "This is a regular response"
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 2: Tool call response
|
||||||
|
t.Run("ToolCallResponse", func(t *testing.T) {
|
||||||
|
// Mock response JSON for a tool call response
|
||||||
|
responseJSON := `{
|
||||||
|
"id": "test-id",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "deepseek-r1",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": null,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_123",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "send_email",
|
||||||
|
"arguments": "{\"address\":\"test@example.com\",\"title\":\"Test\",\"body\":\"This is a test email\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := parser.ParseBytes([]byte(responseJSON))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ParseBytes returned an error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "[TOOL_CALL:send_email]"
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
69
tool_detection_test.go
Normal file
69
tool_detection_test.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestProcessLLMResponse tests the processing of LLM responses to detect tool usage
|
||||||
|
func TestProcessLLMResponse(t *testing.T) {
|
||||||
|
// Test data
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectedOutput string
|
||||||
|
expectedToolName string
|
||||||
|
expectToolCall bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RegularTextResponse",
|
||||||
|
input: "The capital of France is Paris.",
|
||||||
|
expectedOutput: "The capital of France is Paris.",
|
||||||
|
expectedToolName: "",
|
||||||
|
expectToolCall: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ToolCallResponse",
|
||||||
|
input: "[TOOL_CALL:get_current_timestamp]",
|
||||||
|
expectedOutput: "Used tool: get_current_timestamp",
|
||||||
|
expectedToolName: "get_current_timestamp",
|
||||||
|
expectToolCall: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AnotherToolCallResponse",
|
||||||
|
input: "[TOOL_CALL:send_email]",
|
||||||
|
expectedOutput: "Used tool: send_email",
|
||||||
|
expectedToolName: "send_email",
|
||||||
|
expectToolCall: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// This mimics the logic in processLLMResponse function
|
||||||
|
respText := tc.input
|
||||||
|
var toolName string
|
||||||
|
|
||||||
|
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
|
||||||
|
// Extract tool name from the marker
|
||||||
|
toolName = strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
|
||||||
|
// Remove the marker from the response text
|
||||||
|
respText = "Used tool: " + toolName
|
||||||
|
}
|
||||||
|
|
||||||
|
if respText != tc.expectedOutput {
|
||||||
|
t.Errorf("Expected output '%s', got '%s'", tc.expectedOutput, respText)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.expectToolCall {
|
||||||
|
if toolName != tc.expectedToolName {
|
||||||
|
t.Errorf("Expected tool name '%s', got '%s'", tc.expectedToolName, toolName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if toolName != "" {
|
||||||
|
t.Errorf("Expected no tool call, but got tool name '%s'", toolName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user