Enha: model name to config

commit c75ac433d4
parent 0276000bfa
Author: Grail Finder
Date:   2025-09-06 14:41:21 +03:00

5 changed files with 44 additions and 33 deletions


@@ -12,6 +12,7 @@ type Config struct {
 	APIToken       string `toml:"APIToken"`
 	QuestionsPath  string `toml:"QuestionsPath"`
 	RPScenariosDir string `toml:"RPScenariosDir"`
+	ModelName      string `toml:"ModelName"`
 	OutDir         string `toml:"OutDir"`
 	MaxTurns       int    `toml:"MaxTurns"`
 }
@@ -30,6 +31,7 @@ func LoadConfigOrDefault(fn string) *Config {
 		config.RPScenariosDir = "scenarios"
 		config.OutDir = "results"
 		config.MaxTurns = 10
+		config.ModelName = "deepseek/deepseek-chat-v3-0324:free"
 	}
 	return config
 }
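
For reference, a minimal sketch of how the new setting is consumed. It assumes only what the diff above shows (LoadConfigOrDefault, the ModelName field and its default); the config.toml path is the same illustrative name the tests use:

    package main

    import (
    	"fmt"

    	"grailbench/config"
    )

    func main() {
    	// Expected to fall back to the defaults set above (including the
    	// deepseek model) when config.toml is missing or unreadable.
    	cfg := config.LoadConfigOrDefault("config.toml")

    	// To override, config.toml would carry a line matching the struct tag,
    	// e.g.: ModelName = "openai/gpt-4o-mini"
    	fmt.Println(cfg.ModelName)
    }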


@@ -65,9 +65,9 @@ func loadRPScenarios(dir string) ([]*models.RPScenario, error) {

 func chooseParser(apiURL string) RespParser {
 	if apiURL == cfg.RemoteAPI {
-		return NewOpenRouterParser(logger)
+		return NewOpenRouterParser(logger, cfg)
 	}
-	return NewLCPRespParser(logger)
+	return NewLCPRespParser(logger, cfg)
 }

 func init() {
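
To illustrate what this buys at the call site, a hedged sketch (buildRequest is a hypothetical helper in the same package; cfg and logger are the package-level variables already used above):

    // Hypothetical helper, same package: chooseParser now hands back a parser
    // that already carries cfg, so the request body can pick up cfg.ModelName
    // instead of a model name hard-coded in the parser.
    func buildRequest(prompt string) io.Reader {
    	parser := chooseParser(cfg.RemoteAPI)
    	return parser.MakePayload(prompt)
    }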


@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"grailbench/config"
 	"grailbench/models"
 	"io"
 	"log/slog"
@@ -18,10 +19,11 @@ type RespParser interface {
 // DeepSeekParser: deepseek implementation of RespParser
 type deepSeekParser struct {
 	log *slog.Logger
+	cfg *config.Config
 }

-func NewDeepSeekParser(log *slog.Logger) *deepSeekParser {
-	return &deepSeekParser{log: log}
+func NewDeepSeekParser(log *slog.Logger, cfg *config.Config) *deepSeekParser {
+	return &deepSeekParser{log: log, cfg: cfg}
 }

 func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
@@ -101,10 +103,11 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
 // llama.cpp implementation of RespParser
 type lcpRespParser struct {
 	log *slog.Logger
+	cfg *config.Config
 }

-func NewLCPRespParser(log *slog.Logger) *lcpRespParser {
-	return &lcpRespParser{log: log}
+func NewLCPRespParser(log *slog.Logger, cfg *config.Config) *lcpRespParser {
+	return &lcpRespParser{log: log, cfg: cfg}
 }

 func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
@@ -150,17 +153,19 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
 }

 type openRouterParser struct {
 	log           *slog.Logger
+	cfg           *config.Config
 	modelIndex    uint32
 	useChatAPI    bool
 	supportsTools bool
 }

-func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
+func NewOpenRouterParser(log *slog.Logger, cfg *config.Config) *openRouterParser {
 	return &openRouterParser{
 		log:           log,
+		cfg:           cfg,
 		modelIndex:    0,
 		useChatAPI:    false, // Default to completion API which is more widely supported
 		supportsTools: false, // Don't assume tool support
 	}
 }
@@ -213,10 +218,10 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
 	if p.useChatAPI {
 		// Use chat completions API with messages format (supports tool calls)
 		payload := struct {
 			Model    string           `json:"model"`
 			Messages []models.RoleMsg `json:"messages"`
 		}{
-			Model:    "openai/gpt-4o-mini",
+			Model:    p.cfg.ModelName,
 			Messages: []models.RoleMsg{
 				{Role: "user", Content: prompt},
 			},
@@ -236,7 +241,7 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
 		payload := struct {
 			Model  string `json:"model"`
 			Prompt string `json:"prompt"`
 		}{
-			Model:  "openai/gpt-4o-mini",
+			Model:  p.cfg.ModelName,
 			Prompt: prompt,
 		}
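
A rough sketch of the effect on the wire format, run from inside the same package (the prompt is a placeholder, and the exact JSON depends on whatever other payload fields the elided parts of MakePayload define):

    // With useChatAPI at its default (false), MakePayload builds the
    // completion-style body, so "model" now mirrors cfg.ModelName.
    p := NewOpenRouterParser(logger, cfg)
    body, _ := io.ReadAll(p.MakePayload("ping"))
    fmt.Println(string(body))
    // With the default config this should contain something like:
    // {"model":"deepseek/deepseek-chat-v3-0324:free","prompt":"ping"}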


@@ -1,6 +1,7 @@
 package main

 import (
+	"grailbench/config"
 	"log/slog"
 	"os"
 	"testing"
@@ -94,7 +95,7 @@ func TestParserToolDetection(t *testing.T) {

 	// Test DeepSeekParser
 	t.Run("DeepSeekParser", func(t *testing.T) {
-		parser := NewDeepSeekParser(logger)
+		parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml"))

 		for _, tc := range testCases {
 			t.Run(tc.name, func(t *testing.T) {
@@ -112,7 +113,7 @@ func TestParserToolDetection(t *testing.T) {

 	// Test OpenRouterParser
 	t.Run("OpenRouterParser", func(t *testing.T) {
-		parser := NewOpenRouterParser(logger)
+		parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml"))

 		for _, tc := range testCases {
 			t.Run(tc.name, func(t *testing.T) {
@@ -128,3 +129,4 @@ func TestParserToolDetection(t *testing.T) {
 		}
 	})
 }
+
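
The tests lean on config.LoadConfigOrDefault("config.toml"), which should fall back to the built-in defaults when no config file is present. An equivalent, slightly more explicit sketch would pass a literal config instead (the model string reuses the previously hard-coded value purely as an example):

    // Constructing the config inline removes any dependence on config.toml
    // existing in the test environment.
    testCfg := &config.Config{ModelName: "openai/gpt-4o-mini"}
    parser := NewOpenRouterParser(logger, testCfg)
    _ = parser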


@@ -1,6 +1,7 @@
 package main

 import (
+	"grailbench/config"
 	"log/slog"
 	"os"
 	"testing"
@@ -11,7 +12,7 @@ func TestDeepSeekParser_ParseBytes(t *testing.T) {
 	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

 	// Create a new deepSeekParser
-	parser := NewDeepSeekParser(logger)
+	parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml"))

 	// Test case 1: Regular text response
 	t.Run("RegularTextResponse", func(t *testing.T) {
@@ -119,7 +120,7 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) {
 	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

 	// Create a new openRouterParser
-	parser := NewOpenRouterParser(logger)
+	parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml"))

 	// Test case 1: Regular text response
 	t.Run("RegularTextResponse", func(t *testing.T) {
@@ -190,3 +191,4 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) {
 		}
 	})
 }
+