diff --git a/config/config.go b/config/config.go index 24daa47..7e10bb9 100644 --- a/config/config.go +++ b/config/config.go @@ -12,6 +12,7 @@ type Config struct { APIToken string `toml:"APIToken"` QuestionsPath string `toml:"QuestionsPath"` RPScenariosDir string `toml:"RPScenariosDir"` + ModelName string `toml:"ModelName"` OutDir string `toml:"OutDir"` MaxTurns int `toml:"MaxTurns"` } @@ -30,6 +31,7 @@ func LoadConfigOrDefault(fn string) *Config { config.RPScenariosDir = "scenarios" config.OutDir = "results" config.MaxTurns = 10 + config.ModelName = "deepseek/deepseek-chat-v3-0324:free" } return config } diff --git a/main.go b/main.go index 11f8261..2961f02 100644 --- a/main.go +++ b/main.go @@ -65,9 +65,9 @@ func loadRPScenarios(dir string) ([]*models.RPScenario, error) { func chooseParser(apiURL string) RespParser { if apiURL == cfg.RemoteAPI { - return NewOpenRouterParser(logger) + return NewOpenRouterParser(logger, cfg) } - return NewLCPRespParser(logger) + return NewLCPRespParser(logger, cfg) } func init() { diff --git a/parser.go b/parser.go index 81a5c0e..d25324c 100644 --- a/parser.go +++ b/parser.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "grailbench/config" "grailbench/models" "io" "log/slog" @@ -18,10 +19,11 @@ type RespParser interface { // DeepSeekParser: deepseek implementation of RespParser type deepSeekParser struct { log *slog.Logger + cfg *config.Config } -func NewDeepSeekParser(log *slog.Logger) *deepSeekParser { - return &deepSeekParser{log: log} +func NewDeepSeekParser(log *slog.Logger, cfg *config.Config) *deepSeekParser { + return &deepSeekParser{log: log, cfg: cfg} } func (p *deepSeekParser) ParseBytes(body []byte) (string, error) { @@ -36,10 +38,10 @@ func (p *deepSeekParser) ParseBytes(body []byte) (string, error) { err := errors.New("empty choices in dsResp") return "", err } - + // Check if the response contains tool calls choice := dsResp.Choices[0] - + // Handle response with message field (OpenAI format) if choice.Message.Role != "" { if len(choice.Message.ToolCalls) > 0 { @@ -51,7 +53,7 @@ func (p *deepSeekParser) ParseBytes(body []byte) (string, error) { // Regular text response return choice.Message.Content, nil } - + // Handle response with text field (legacy format) return choice.Text, nil } @@ -101,10 +103,11 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader { // llama.cpp implementation of RespParser type lcpRespParser struct { log *slog.Logger + cfg *config.Config } -func NewLCPRespParser(log *slog.Logger) *lcpRespParser { - return &lcpRespParser{log: log} +func NewLCPRespParser(log *slog.Logger, cfg *config.Config) *lcpRespParser { + return &lcpRespParser{log: log, cfg: cfg} } func (p *lcpRespParser) ParseBytes(body []byte) (string, error) { @@ -150,17 +153,19 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader { } type openRouterParser struct { - log *slog.Logger - modelIndex uint32 - useChatAPI bool + log *slog.Logger + cfg *config.Config + modelIndex uint32 + useChatAPI bool supportsTools bool } -func NewOpenRouterParser(log *slog.Logger) *openRouterParser { +func NewOpenRouterParser(log *slog.Logger, cfg *config.Config) *openRouterParser { return &openRouterParser{ - log: log, - modelIndex: 0, - useChatAPI: false, // Default to completion API which is more widely supported + log: log, + cfg: cfg, + modelIndex: 0, + useChatAPI: false, // Default to completion API which is more widely supported supportsTools: false, // Don't assume tool support } } @@ -178,9 +183,9 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) { err := errors.New("empty choices in openrouter chat response") return "", err } - + choice := resp.Choices[0] - + // Check if the response contains tool calls if len(choice.Message.ToolCalls) > 0 { // Handle tool call response @@ -188,11 +193,11 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) { // Return a special marker indicating tool usage return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil } - + // Regular text response return choice.Message.Content, nil } - + // If using completion API, parse as text completion response (no tool calls) resp := models.ORCompletionResp{} if err := json.Unmarshal(body, &resp); err != nil { @@ -204,7 +209,7 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) { err := errors.New("empty choices in openrouter completion response") return "", err } - + // Return the text content return resp.Choices[0].Text, nil } @@ -213,10 +218,10 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader { if p.useChatAPI { // Use chat completions API with messages format (supports tool calls) payload := struct { - Model string `json:"model"` + Model string `json:"model"` Messages []models.RoleMsg `json:"messages"` }{ - Model: "openai/gpt-4o-mini", + Model: p.cfg.ModelName, Messages: []models.RoleMsg{ {Role: "user", Content: prompt}, }, @@ -230,13 +235,13 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader { p.log.Debug("made openrouter chat payload", "payload", string(b)) return bytes.NewReader(b) } - + // Use completions API with prompt format (no tool calls) payload := struct { Model string `json:"model"` Prompt string `json:"prompt"` }{ - Model: "openai/gpt-4o-mini", + Model: p.cfg.ModelName, Prompt: prompt, } diff --git a/parser_integration_test.go b/parser_integration_test.go index 3cd722f..ea2eb34 100644 --- a/parser_integration_test.go +++ b/parser_integration_test.go @@ -1,6 +1,7 @@ package main import ( + "grailbench/config" "log/slog" "os" "testing" @@ -94,8 +95,8 @@ func TestParserToolDetection(t *testing.T) { // Test DeepSeekParser t.Run("DeepSeekParser", func(t *testing.T) { - parser := NewDeepSeekParser(logger) - + parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml")) + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { result, err := parser.ParseBytes([]byte(tc.responseJSON)) @@ -112,8 +113,8 @@ func TestParserToolDetection(t *testing.T) { // Test OpenRouterParser t.Run("OpenRouterParser", func(t *testing.T) { - parser := NewOpenRouterParser(logger) - + parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml")) + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { result, err := parser.ParseBytes([]byte(tc.responseJSON)) @@ -127,4 +128,5 @@ func TestParserToolDetection(t *testing.T) { }) } }) -} \ No newline at end of file +} + diff --git a/parser_test.go b/parser_test.go index 30c8b0e..0114082 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1,6 +1,7 @@ package main import ( + "grailbench/config" "log/slog" "os" "testing" @@ -11,7 +12,7 @@ func TestDeepSeekParser_ParseBytes(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) // Create a new deepSeekParser - parser := NewDeepSeekParser(logger) + parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml")) // Test case 1: Regular text response t.Run("RegularTextResponse", func(t *testing.T) { @@ -119,7 +120,7 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) // Create a new openRouterParser - parser := NewOpenRouterParser(logger) + parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml")) // Test case 1: Regular text response t.Run("RegularTextResponse", func(t *testing.T) { @@ -189,4 +190,5 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) { t.Errorf("Expected %s, got %s", expected, result) } }) -} \ No newline at end of file +} +