Enha: model name to config
This commit is contained in:
@@ -12,6 +12,7 @@ type Config struct {
|
|||||||
APIToken string `toml:"APIToken"`
|
APIToken string `toml:"APIToken"`
|
||||||
QuestionsPath string `toml:"QuestionsPath"`
|
QuestionsPath string `toml:"QuestionsPath"`
|
||||||
RPScenariosDir string `toml:"RPScenariosDir"`
|
RPScenariosDir string `toml:"RPScenariosDir"`
|
||||||
|
ModelName string `toml:"ModelName"`
|
||||||
OutDir string `toml:"OutDir"`
|
OutDir string `toml:"OutDir"`
|
||||||
MaxTurns int `toml:"MaxTurns"`
|
MaxTurns int `toml:"MaxTurns"`
|
||||||
}
|
}
|
||||||
@@ -30,6 +31,7 @@ func LoadConfigOrDefault(fn string) *Config {
|
|||||||
config.RPScenariosDir = "scenarios"
|
config.RPScenariosDir = "scenarios"
|
||||||
config.OutDir = "results"
|
config.OutDir = "results"
|
||||||
config.MaxTurns = 10
|
config.MaxTurns = 10
|
||||||
|
config.ModelName = "deepseek/deepseek-chat-v3-0324:free"
|
||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
4
main.go
4
main.go
@@ -65,9 +65,9 @@ func loadRPScenarios(dir string) ([]*models.RPScenario, error) {
|
|||||||
|
|
||||||
func chooseParser(apiURL string) RespParser {
|
func chooseParser(apiURL string) RespParser {
|
||||||
if apiURL == cfg.RemoteAPI {
|
if apiURL == cfg.RemoteAPI {
|
||||||
return NewOpenRouterParser(logger)
|
return NewOpenRouterParser(logger, cfg)
|
||||||
}
|
}
|
||||||
return NewLCPRespParser(logger)
|
return NewLCPRespParser(logger, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
33
parser.go
33
parser.go
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"grailbench/config"
|
||||||
"grailbench/models"
|
"grailbench/models"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -18,10 +19,11 @@ type RespParser interface {
|
|||||||
// DeepSeekParser: deepseek implementation of RespParser
|
// DeepSeekParser: deepseek implementation of RespParser
|
||||||
type deepSeekParser struct {
|
type deepSeekParser struct {
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDeepSeekParser(log *slog.Logger) *deepSeekParser {
|
func NewDeepSeekParser(log *slog.Logger, cfg *config.Config) *deepSeekParser {
|
||||||
return &deepSeekParser{log: log}
|
return &deepSeekParser{log: log, cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
||||||
@@ -101,10 +103,11 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
|||||||
// llama.cpp implementation of RespParser
|
// llama.cpp implementation of RespParser
|
||||||
type lcpRespParser struct {
|
type lcpRespParser struct {
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLCPRespParser(log *slog.Logger) *lcpRespParser {
|
func NewLCPRespParser(log *slog.Logger, cfg *config.Config) *lcpRespParser {
|
||||||
return &lcpRespParser{log: log}
|
return &lcpRespParser{log: log, cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
|
func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
|
||||||
@@ -150,17 +153,19 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type openRouterParser struct {
|
type openRouterParser struct {
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
modelIndex uint32
|
cfg *config.Config
|
||||||
useChatAPI bool
|
modelIndex uint32
|
||||||
|
useChatAPI bool
|
||||||
supportsTools bool
|
supportsTools bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
|
func NewOpenRouterParser(log *slog.Logger, cfg *config.Config) *openRouterParser {
|
||||||
return &openRouterParser{
|
return &openRouterParser{
|
||||||
log: log,
|
log: log,
|
||||||
modelIndex: 0,
|
cfg: cfg,
|
||||||
useChatAPI: false, // Default to completion API which is more widely supported
|
modelIndex: 0,
|
||||||
|
useChatAPI: false, // Default to completion API which is more widely supported
|
||||||
supportsTools: false, // Don't assume tool support
|
supportsTools: false, // Don't assume tool support
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -213,10 +218,10 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
|||||||
if p.useChatAPI {
|
if p.useChatAPI {
|
||||||
// Use chat completions API with messages format (supports tool calls)
|
// Use chat completions API with messages format (supports tool calls)
|
||||||
payload := struct {
|
payload := struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []models.RoleMsg `json:"messages"`
|
Messages []models.RoleMsg `json:"messages"`
|
||||||
}{
|
}{
|
||||||
Model: "openai/gpt-4o-mini",
|
Model: p.cfg.ModelName,
|
||||||
Messages: []models.RoleMsg{
|
Messages: []models.RoleMsg{
|
||||||
{Role: "user", Content: prompt},
|
{Role: "user", Content: prompt},
|
||||||
},
|
},
|
||||||
@@ -236,7 +241,7 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
}{
|
}{
|
||||||
Model: "openai/gpt-4o-mini",
|
Model: p.cfg.ModelName,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"grailbench/config"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -94,7 +95,7 @@ func TestParserToolDetection(t *testing.T) {
|
|||||||
|
|
||||||
// Test DeepSeekParser
|
// Test DeepSeekParser
|
||||||
t.Run("DeepSeekParser", func(t *testing.T) {
|
t.Run("DeepSeekParser", func(t *testing.T) {
|
||||||
parser := NewDeepSeekParser(logger)
|
parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
@@ -112,7 +113,7 @@ func TestParserToolDetection(t *testing.T) {
|
|||||||
|
|
||||||
// Test OpenRouterParser
|
// Test OpenRouterParser
|
||||||
t.Run("OpenRouterParser", func(t *testing.T) {
|
t.Run("OpenRouterParser", func(t *testing.T) {
|
||||||
parser := NewOpenRouterParser(logger)
|
parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
@@ -128,3 +129,4 @@ func TestParserToolDetection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"grailbench/config"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -11,7 +12,7 @@ func TestDeepSeekParser_ParseBytes(t *testing.T) {
|
|||||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
// Create a new deepSeekParser
|
// Create a new deepSeekParser
|
||||||
parser := NewDeepSeekParser(logger)
|
parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||||
|
|
||||||
// Test case 1: Regular text response
|
// Test case 1: Regular text response
|
||||||
t.Run("RegularTextResponse", func(t *testing.T) {
|
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||||
@@ -119,7 +120,7 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) {
|
|||||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
// Create a new openRouterParser
|
// Create a new openRouterParser
|
||||||
parser := NewOpenRouterParser(logger)
|
parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||||
|
|
||||||
// Test case 1: Regular text response
|
// Test case 1: Regular text response
|
||||||
t.Run("RegularTextResponse", func(t *testing.T) {
|
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||||
@@ -190,3 +191,4 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user