diff --git a/main.go b/main.go index ef44924..d58bf1a 100644 --- a/main.go +++ b/main.go @@ -180,6 +180,22 @@ func buildRPPrompt(sysPrompt string, history []models.RPMessage, characterName s return b.String() } +func processLLMResponse(respText string) (string, models.ToolCallInfo) { + // Treat the response as a tool call when it starts with "[TOOL_CALL:" and ends with "]"; otherwise the text and a zero-value ToolCallInfo are returned unchanged + var toolCall models.ToolCallInfo + if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") { + // The tool name is the text between the "[TOOL_CALL:" prefix and the trailing "]" + toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:") + toolCall = models.ToolCallInfo{ + Name: toolName, + // Only Name is set; arguments would need to be parsed from the actual response + } + // Replace the whole response text with a "Used tool: <name>" summary + respText = fmt.Sprintf("Used tool: %s", toolName) + } + return respText, toolCall +} + func runBench(questions []models.Question) ([]models.Answer, error) { answers := []models.Answer{} for _, q := range questions { @@ -195,18 +211,8 @@ func runBench(questions []models.Question) ([]models.Answer, error) { continue } - // Check if the response indicates tool usage - var toolCall models.ToolCallInfo - if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") { - // Extract tool name from the marker - toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:") - toolCall = models.ToolCallInfo{ - Name: toolName, - // Arguments would need to be parsed from the actual response - } - // Remove the marker from the response text - respText = fmt.Sprintf("Used tool: %s", toolName) - } + // Detect tool usage; for a tool-call marker respText becomes "Used tool: <name>" + respText, toolCall := processLLMResponse(respText) a := models.Answer{ Q: q, diff --git a/parser_integration_test.go b/parser_integration_test.go new file mode 100644 index 0000000..3cd722f --- /dev/null +++ b/parser_integration_test.go @@ -0,0 +1,130 @@ +package main + +import ( + "log/slog" + "os" + "testing" +) + +// TestParserToolDetection tests the parser 
implementations with mock responses +func TestParserToolDetection(t *testing.T) { + // Text-format logger writing to stderr, handed to both parser constructors + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + + // Shared mock responses: a text completion, a chat completion with tool_calls, and a plain chat message. NOTE(review): isToolCall and expectedTool are set here but never asserted by the subtests below + testCases := []struct { + name string + responseJSON string + expectedOutput string + isToolCall bool + expectedTool string + }{ + { + name: "RegularTextResponse", + responseJSON: `{ + "id": "test-id", + "object": "text.completion", + "created": 1234567890, + "model": "deepseek-chat", + "choices": [ + { + "text": "The capital of France is Paris.", + "index": 0, + "finish_reason": "stop" + } + ] + }`, + expectedOutput: "The capital of France is Paris.", + isToolCall: false, + }, + { + name: "ToolCallResponse", + responseJSON: `{ + "id": "test-id", + "object": "chat.completion", + "created": 1234567890, + "model": "deepseek-chat", + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_current_timestamp", + "arguments": "{}" + } + } + ] + } + } + ] + }`, + expectedOutput: "[TOOL_CALL:get_current_timestamp]", + isToolCall: true, + expectedTool: "get_current_timestamp", + }, + { + name: "RegularMessageResponse", + responseJSON: `{ + "id": "test-id", + "object": "chat.completion", + "created": 1234567890, + "model": "deepseek-chat", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "Hello, how can I help you today?" 
+ } + } + ] + }`, + expectedOutput: "Hello, how can I help you today?", + isToolCall: false, + }, + } + + // Run every shared case through the DeepSeek parser; only the returned string is compared against expectedOutput + t.Run("DeepSeekParser", func(t *testing.T) { + parser := NewDeepSeekParser(logger) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := parser.ParseBytes([]byte(tc.responseJSON)) + if err != nil { + t.Errorf("ParseBytes returned an error: %v", err) + } + + if result != tc.expectedOutput { + t.Errorf("Expected %s, got %s", tc.expectedOutput, result) + } + }) + } + }) + + // Run the same shared cases through the OpenRouter parser; only the returned string is compared against expectedOutput + t.Run("OpenRouterParser", func(t *testing.T) { + parser := NewOpenRouterParser(logger) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := parser.ParseBytes([]byte(tc.responseJSON)) + if err != nil { + t.Errorf("ParseBytes returned an error: %v", err) + } + + if result != tc.expectedOutput { + t.Errorf("Expected %s, got %s", tc.expectedOutput, result) + } + }) + } + }) +} \ No newline at end of file diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..30c8b0e --- /dev/null +++ b/parser_test.go @@ -0,0 +1,192 @@ +package main + +import ( + "log/slog" + "os" + "testing" +) + +func TestDeepSeekParser_ParseBytes(t *testing.T) { + // Text-format logger writing to stderr + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + + // Parser under test, shared by all subtests below + parser := NewDeepSeekParser(logger) + + // Test case 1: text completion; the expected result is the value of the choice's "text" field + t.Run("RegularTextResponse", func(t *testing.T) { + // Mock response JSON for a regular text response + responseJSON := `{ + "id": "test-id", + "object": "text.completion", + "created": 1234567890, + "model": "deepseek-chat", + "choices": [ + { + "text": "This is a regular response", + "index": 0, + "finish_reason": "stop" + } + ] + }` + + result, err := parser.ParseBytes([]byte(responseJSON)) + if err != nil { + t.Errorf("ParseBytes returned an error: %v", err) + } + + expected := "This is a regular response" + if 
result != expected { + t.Errorf("Expected %s, got %s", expected, result) + } + }) + + // Test case 2: tool_calls response; the expected result is a "[TOOL_CALL:<function name>]" marker + t.Run("ToolCallResponse", func(t *testing.T) { + // Mock response JSON for a tool call response + responseJSON := `{ + "id": "test-id", + "object": "chat.completion", + "created": 1234567890, + "model": "deepseek-chat", + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_current_timestamp", + "arguments": "{}" + } + } + ] + } + } + ] + }` + + result, err := parser.ParseBytes([]byte(responseJSON)) + if err != nil { + t.Errorf("ParseBytes returned an error: %v", err) + } + + expected := "[TOOL_CALL:get_current_timestamp]" + if result != expected { + t.Errorf("Expected %s, got %s", expected, result) + } + }) + + // Test case 3: regular message response (OpenAI format); the expected result is the value of message.content + t.Run("RegularMessageResponse", func(t *testing.T) { + // Mock response JSON for a regular message response + responseJSON := `{ + "id": "test-id", + "object": "chat.completion", + "created": 1234567890, + "model": "deepseek-chat", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "This is a regular message response" + } + } + ] + }` + + result, err := parser.ParseBytes([]byte(responseJSON)) + if err != nil { + t.Errorf("ParseBytes returned an error: %v", err) + } + + expected := "This is a regular message response" + if result != expected { + t.Errorf("Expected %s, got %s", expected, result) + } + }) +} + +func TestOpenRouterParser_ParseBytes(t *testing.T) { + // Text-format logger writing to stderr + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + + // Parser under test, shared by all subtests below + parser := NewOpenRouterParser(logger) + + // Test case 1: text completion; the expected result is the value of the choice's "text" field + t.Run("RegularTextResponse", func(t *testing.T) { + // Mock response JSON for a regular text 
response + responseJSON := `{ + "id": "test-id", + "object": "text.completion", + "created": 1234567890, + "model": "deepseek-r1", + "choices": [ + { + "text": "This is a regular response", + "index": 0, + "finish_reason": "stop" + } + ] + }` + + result, err := parser.ParseBytes([]byte(responseJSON)) + if err != nil { + t.Errorf("ParseBytes returned an error: %v", err) + } + + expected := "This is a regular response" + if result != expected { + t.Errorf("Expected %s, got %s", expected, result) + } + }) + + // Test case 2: tool_calls response with non-empty arguments; the expected marker carries only the function name, not the arguments + t.Run("ToolCallResponse", func(t *testing.T) { + // Mock response JSON for a tool call response + responseJSON := `{ + "id": "test-id", + "object": "chat.completion", + "created": 1234567890, + "model": "deepseek-r1", + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "send_email", + "arguments": "{\"address\":\"test@example.com\",\"title\":\"Test\",\"body\":\"This is a test email\"}" + } + } + ] + } + } + ] + }` + + result, err := parser.ParseBytes([]byte(responseJSON)) + if err != nil { + t.Errorf("ParseBytes returned an error: %v", err) + } + + expected := "[TOOL_CALL:send_email]" + if result != expected { + t.Errorf("Expected %s, got %s", expected, result) + } + }) +} \ No newline at end of file diff --git a/tool_detection_test.go b/tool_detection_test.go new file mode 100644 index 0000000..2f42d54 --- /dev/null +++ b/tool_detection_test.go @@ -0,0 +1,69 @@ +package main + +import ( + "strings" + "testing" +) + +// TestProcessLLMResponse tests the processing of LLM responses to detect tool usage +func TestProcessLLMResponse(t *testing.T) { + // Table of inputs with the expected rewritten text and the expected tool name + testCases := []struct { + name string + input string + expectedOutput string + expectedToolName string + expectToolCall bool + }{ + { + name: "RegularTextResponse", + input: "The capital of France is Paris.", + 
expectedOutput: "The capital of France is Paris.", + expectedToolName: "", + expectToolCall: false, + }, + { + name: "ToolCallResponse", + input: "[TOOL_CALL:get_current_timestamp]", + expectedOutput: "Used tool: get_current_timestamp", + expectedToolName: "get_current_timestamp", + expectToolCall: true, + }, + { + name: "AnotherToolCallResponse", + input: "[TOOL_CALL:send_email]", + expectedOutput: "Used tool: send_email", + expectedToolName: "send_email", + expectToolCall: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // NOTE(review): this re-implements the marker check inline instead of calling processLLMResponse, so the real function is not exercised by this test + respText := tc.input + var toolName string + + if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") { + // The tool name is the text between the "[TOOL_CALL:" prefix and the trailing "]" + toolName = strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:") + // Replace the response text with the "Used tool: <name>" summary + respText = "Used tool: " + toolName + } + + if respText != tc.expectedOutput { + t.Errorf("Expected output '%s', got '%s'", tc.expectedOutput, respText) + } + + if tc.expectToolCall { + if toolName != tc.expectedToolName { + t.Errorf("Expected tool name '%s', got '%s'", tc.expectedToolName, toolName) + } + } else { + if toolName != "" { + t.Errorf("Expected no tool call, but got tool name '%s'", toolName) + } + } + }) + } +} \ No newline at end of file