Feat: add tests

This commit is contained in:
Grail Finder
2025-09-05 14:21:18 +03:00
parent 8699b1a84e
commit 53dc5a5e8d
4 changed files with 409 additions and 12 deletions

30
main.go
View File

@@ -180,6 +180,22 @@ func buildRPPrompt(sysPrompt string, history []models.RPMessage, characterName s
return b.String()
}
func processLLMResponse(respText string) (string, models.ToolCallInfo) {
// Check if the response indicates tool usage
var toolCall models.ToolCallInfo
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
// Extract tool name from the marker
toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
toolCall = models.ToolCallInfo{
Name: toolName,
// Arguments would need to be parsed from the actual response
}
// Remove the marker from the response text
respText = fmt.Sprintf("Used tool: %s", toolName)
}
return respText, toolCall
}
func runBench(questions []models.Question) ([]models.Answer, error) {
answers := []models.Answer{}
for _, q := range questions {
@@ -195,18 +211,8 @@ func runBench(questions []models.Question) ([]models.Answer, error) {
continue
}
// Check if the response indicates tool usage
var toolCall models.ToolCallInfo
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
// Extract tool name from the marker
toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
toolCall = models.ToolCallInfo{
Name: toolName,
// Arguments would need to be parsed from the actual response
}
// Remove the marker from the response text
respText = fmt.Sprintf("Used tool: %s", toolName)
}
// Process the response to detect tool usage
respText, toolCall := processLLMResponse(respText)
a := models.Answer{
Q: q,

130
parser_integration_test.go Normal file
View File

@@ -0,0 +1,130 @@
package main
import (
"log/slog"
"os"
"testing"
)
// TestParserToolDetection tests the parser implementations with mock responses
func TestParserToolDetection(t *testing.T) {
// Create a logger for testing
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Test cases with different response types
testCases := []struct {
name string
responseJSON string
expectedOutput string
isToolCall bool
expectedTool string
}{
{
name: "RegularTextResponse",
responseJSON: `{
"id": "test-id",
"object": "text.completion",
"created": 1234567890,
"model": "deepseek-chat",
"choices": [
{
"text": "The capital of France is Paris.",
"index": 0,
"finish_reason": "stop"
}
]
}`,
expectedOutput: "The capital of France is Paris.",
isToolCall: false,
},
{
name: "ToolCallResponse",
responseJSON: `{
"id": "test-id",
"object": "chat.completion",
"created": 1234567890,
"model": "deepseek-chat",
"choices": [
{
"index": 0,
"finish_reason": "tool_calls",
"message": {
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_current_timestamp",
"arguments": "{}"
}
}
]
}
}
]
}`,
expectedOutput: "[TOOL_CALL:get_current_timestamp]",
isToolCall: true,
expectedTool: "get_current_timestamp",
},
{
name: "RegularMessageResponse",
responseJSON: `{
"id": "test-id",
"object": "chat.completion",
"created": 1234567890,
"model": "deepseek-chat",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": "Hello, how can I help you today?"
}
}
]
}`,
expectedOutput: "Hello, how can I help you today?",
isToolCall: false,
},
}
// Test DeepSeekParser
t.Run("DeepSeekParser", func(t *testing.T) {
parser := NewDeepSeekParser(logger)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := parser.ParseBytes([]byte(tc.responseJSON))
if err != nil {
t.Errorf("ParseBytes returned an error: %v", err)
}
if result != tc.expectedOutput {
t.Errorf("Expected %s, got %s", tc.expectedOutput, result)
}
})
}
})
// Test OpenRouterParser
t.Run("OpenRouterParser", func(t *testing.T) {
parser := NewOpenRouterParser(logger)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := parser.ParseBytes([]byte(tc.responseJSON))
if err != nil {
t.Errorf("ParseBytes returned an error: %v", err)
}
if result != tc.expectedOutput {
t.Errorf("Expected %s, got %s", tc.expectedOutput, result)
}
})
}
})
}

192
parser_test.go Normal file
View File

@@ -0,0 +1,192 @@
package main
import (
"log/slog"
"os"
"testing"
)
func TestDeepSeekParser_ParseBytes(t *testing.T) {
// Create a logger for testing
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Create a new deepSeekParser
parser := NewDeepSeekParser(logger)
// Test case 1: Regular text response
t.Run("RegularTextResponse", func(t *testing.T) {
// Mock response JSON for a regular text response
responseJSON := `{
"id": "test-id",
"object": "text.completion",
"created": 1234567890,
"model": "deepseek-chat",
"choices": [
{
"text": "This is a regular response",
"index": 0,
"finish_reason": "stop"
}
]
}`
result, err := parser.ParseBytes([]byte(responseJSON))
if err != nil {
t.Errorf("ParseBytes returned an error: %v", err)
}
expected := "This is a regular response"
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}
})
// Test case 2: Tool call response
t.Run("ToolCallResponse", func(t *testing.T) {
// Mock response JSON for a tool call response
responseJSON := `{
"id": "test-id",
"object": "chat.completion",
"created": 1234567890,
"model": "deepseek-chat",
"choices": [
{
"index": 0,
"finish_reason": "tool_calls",
"message": {
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_current_timestamp",
"arguments": "{}"
}
}
]
}
}
]
}`
result, err := parser.ParseBytes([]byte(responseJSON))
if err != nil {
t.Errorf("ParseBytes returned an error: %v", err)
}
expected := "[TOOL_CALL:get_current_timestamp]"
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}
})
// Test case 3: Regular message response (OpenAI format)
t.Run("RegularMessageResponse", func(t *testing.T) {
// Mock response JSON for a regular message response
responseJSON := `{
"id": "test-id",
"object": "chat.completion",
"created": 1234567890,
"model": "deepseek-chat",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": "This is a regular message response"
}
}
]
}`
result, err := parser.ParseBytes([]byte(responseJSON))
if err != nil {
t.Errorf("ParseBytes returned an error: %v", err)
}
expected := "This is a regular message response"
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}
})
}
func TestOpenRouterParser_ParseBytes(t *testing.T) {
// Create a logger for testing
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Create a new openRouterParser
parser := NewOpenRouterParser(logger)
// Test case 1: Regular text response
t.Run("RegularTextResponse", func(t *testing.T) {
// Mock response JSON for a regular text response
responseJSON := `{
"id": "test-id",
"object": "text.completion",
"created": 1234567890,
"model": "deepseek-r1",
"choices": [
{
"text": "This is a regular response",
"index": 0,
"finish_reason": "stop"
}
]
}`
result, err := parser.ParseBytes([]byte(responseJSON))
if err != nil {
t.Errorf("ParseBytes returned an error: %v", err)
}
expected := "This is a regular response"
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}
})
// Test case 2: Tool call response
t.Run("ToolCallResponse", func(t *testing.T) {
// Mock response JSON for a tool call response
responseJSON := `{
"id": "test-id",
"object": "chat.completion",
"created": 1234567890,
"model": "deepseek-r1",
"choices": [
{
"index": 0,
"finish_reason": "tool_calls",
"message": {
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "send_email",
"arguments": "{\"address\":\"test@example.com\",\"title\":\"Test\",\"body\":\"This is a test email\"}"
}
}
]
}
}
]
}`
result, err := parser.ParseBytes([]byte(responseJSON))
if err != nil {
t.Errorf("ParseBytes returned an error: %v", err)
}
expected := "[TOOL_CALL:send_email]"
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}
})
}

69
tool_detection_test.go Normal file
View File

@@ -0,0 +1,69 @@
package main
import (
"strings"
"testing"
)
// TestProcessLLMResponse tests the processing of LLM responses to detect tool usage
func TestProcessLLMResponse(t *testing.T) {
// Test data
testCases := []struct {
name string
input string
expectedOutput string
expectedToolName string
expectToolCall bool
}{
{
name: "RegularTextResponse",
input: "The capital of France is Paris.",
expectedOutput: "The capital of France is Paris.",
expectedToolName: "",
expectToolCall: false,
},
{
name: "ToolCallResponse",
input: "[TOOL_CALL:get_current_timestamp]",
expectedOutput: "Used tool: get_current_timestamp",
expectedToolName: "get_current_timestamp",
expectToolCall: true,
},
{
name: "AnotherToolCallResponse",
input: "[TOOL_CALL:send_email]",
expectedOutput: "Used tool: send_email",
expectedToolName: "send_email",
expectToolCall: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// This mimics the logic in processLLMResponse function
respText := tc.input
var toolName string
if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
// Extract tool name from the marker
toolName = strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
// Remove the marker from the response text
respText = "Used tool: " + toolName
}
if respText != tc.expectedOutput {
t.Errorf("Expected output '%s', got '%s'", tc.expectedOutput, respText)
}
if tc.expectToolCall {
if toolName != tc.expectedToolName {
t.Errorf("Expected tool name '%s', got '%s'", tc.expectedToolName, toolName)
}
} else {
if toolName != "" {
t.Errorf("Expected no tool call, but got tool name '%s'", toolName)
}
}
})
}
}