Enha: model name to config
This commit is contained in:
@@ -12,6 +12,7 @@ type Config struct {
|
||||
APIToken string `toml:"APIToken"`
|
||||
QuestionsPath string `toml:"QuestionsPath"`
|
||||
RPScenariosDir string `toml:"RPScenariosDir"`
|
||||
ModelName string `toml:"ModelName"`
|
||||
OutDir string `toml:"OutDir"`
|
||||
MaxTurns int `toml:"MaxTurns"`
|
||||
}
|
||||
@@ -30,6 +31,7 @@ func LoadConfigOrDefault(fn string) *Config {
|
||||
config.RPScenariosDir = "scenarios"
|
||||
config.OutDir = "results"
|
||||
config.MaxTurns = 10
|
||||
config.ModelName = "deepseek/deepseek-chat-v3-0324:free"
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
4
main.go
4
main.go
@@ -65,9 +65,9 @@ func loadRPScenarios(dir string) ([]*models.RPScenario, error) {
|
||||
|
||||
func chooseParser(apiURL string) RespParser {
|
||||
if apiURL == cfg.RemoteAPI {
|
||||
return NewOpenRouterParser(logger)
|
||||
return NewOpenRouterParser(logger, cfg)
|
||||
}
|
||||
return NewLCPRespParser(logger)
|
||||
return NewLCPRespParser(logger, cfg)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
19
parser.go
19
parser.go
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"grailbench/config"
|
||||
"grailbench/models"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -18,10 +19,11 @@ type RespParser interface {
|
||||
// DeepSeekParser: deepseek implementation of RespParser
|
||||
type deepSeekParser struct {
|
||||
log *slog.Logger
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewDeepSeekParser(log *slog.Logger) *deepSeekParser {
|
||||
return &deepSeekParser{log: log}
|
||||
func NewDeepSeekParser(log *slog.Logger, cfg *config.Config) *deepSeekParser {
|
||||
return &deepSeekParser{log: log, cfg: cfg}
|
||||
}
|
||||
|
||||
func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
||||
@@ -101,10 +103,11 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
||||
// llama.cpp implementation of RespParser
|
||||
type lcpRespParser struct {
|
||||
log *slog.Logger
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewLCPRespParser(log *slog.Logger) *lcpRespParser {
|
||||
return &lcpRespParser{log: log}
|
||||
func NewLCPRespParser(log *slog.Logger, cfg *config.Config) *lcpRespParser {
|
||||
return &lcpRespParser{log: log, cfg: cfg}
|
||||
}
|
||||
|
||||
func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
|
||||
@@ -151,14 +154,16 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
|
||||
|
||||
type openRouterParser struct {
|
||||
log *slog.Logger
|
||||
cfg *config.Config
|
||||
modelIndex uint32
|
||||
useChatAPI bool
|
||||
supportsTools bool
|
||||
}
|
||||
|
||||
func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
|
||||
func NewOpenRouterParser(log *slog.Logger, cfg *config.Config) *openRouterParser {
|
||||
return &openRouterParser{
|
||||
log: log,
|
||||
cfg: cfg,
|
||||
modelIndex: 0,
|
||||
useChatAPI: false, // Default to completion API which is more widely supported
|
||||
supportsTools: false, // Don't assume tool support
|
||||
@@ -216,7 +221,7 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
||||
Model string `json:"model"`
|
||||
Messages []models.RoleMsg `json:"messages"`
|
||||
}{
|
||||
Model: "openai/gpt-4o-mini",
|
||||
Model: p.cfg.ModelName,
|
||||
Messages: []models.RoleMsg{
|
||||
{Role: "user", Content: prompt},
|
||||
},
|
||||
@@ -236,7 +241,7 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
}{
|
||||
Model: "openai/gpt-4o-mini",
|
||||
Model: p.cfg.ModelName,
|
||||
Prompt: prompt,
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"grailbench/config"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
@@ -94,7 +95,7 @@ func TestParserToolDetection(t *testing.T) {
|
||||
|
||||
// Test DeepSeekParser
|
||||
t.Run("DeepSeekParser", func(t *testing.T) {
|
||||
parser := NewDeepSeekParser(logger)
|
||||
parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -112,7 +113,7 @@ func TestParserToolDetection(t *testing.T) {
|
||||
|
||||
// Test OpenRouterParser
|
||||
t.Run("OpenRouterParser", func(t *testing.T) {
|
||||
parser := NewOpenRouterParser(logger)
|
||||
parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -128,3 +129,4 @@ func TestParserToolDetection(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"grailbench/config"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
@@ -11,7 +12,7 @@ func TestDeepSeekParser_ParseBytes(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
// Create a new deepSeekParser
|
||||
parser := NewDeepSeekParser(logger)
|
||||
parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||
|
||||
// Test case 1: Regular text response
|
||||
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||
@@ -119,7 +120,7 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
// Create a new openRouterParser
|
||||
parser := NewOpenRouterParser(logger)
|
||||
parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||
|
||||
// Test case 1: Regular text response
|
||||
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||
@@ -190,3 +191,4 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user