Enha: model name to config
This commit is contained in:
@@ -12,6 +12,7 @@ type Config struct {
|
|||||||
APIToken string `toml:"APIToken"`
|
APIToken string `toml:"APIToken"`
|
||||||
QuestionsPath string `toml:"QuestionsPath"`
|
QuestionsPath string `toml:"QuestionsPath"`
|
||||||
RPScenariosDir string `toml:"RPScenariosDir"`
|
RPScenariosDir string `toml:"RPScenariosDir"`
|
||||||
|
ModelName string `toml:"ModelName"`
|
||||||
OutDir string `toml:"OutDir"`
|
OutDir string `toml:"OutDir"`
|
||||||
MaxTurns int `toml:"MaxTurns"`
|
MaxTurns int `toml:"MaxTurns"`
|
||||||
}
|
}
|
||||||
@@ -30,6 +31,7 @@ func LoadConfigOrDefault(fn string) *Config {
|
|||||||
config.RPScenariosDir = "scenarios"
|
config.RPScenariosDir = "scenarios"
|
||||||
config.OutDir = "results"
|
config.OutDir = "results"
|
||||||
config.MaxTurns = 10
|
config.MaxTurns = 10
|
||||||
|
config.ModelName = "deepseek/deepseek-chat-v3-0324:free"
|
||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
4
main.go
4
main.go
@@ -65,9 +65,9 @@ func loadRPScenarios(dir string) ([]*models.RPScenario, error) {
|
|||||||
|
|
||||||
func chooseParser(apiURL string) RespParser {
|
func chooseParser(apiURL string) RespParser {
|
||||||
if apiURL == cfg.RemoteAPI {
|
if apiURL == cfg.RemoteAPI {
|
||||||
return NewOpenRouterParser(logger)
|
return NewOpenRouterParser(logger, cfg)
|
||||||
}
|
}
|
||||||
return NewLCPRespParser(logger)
|
return NewLCPRespParser(logger, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
51
parser.go
51
parser.go
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"grailbench/config"
|
||||||
"grailbench/models"
|
"grailbench/models"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -18,10 +19,11 @@ type RespParser interface {
|
|||||||
// DeepSeekParser: deepseek implementation of RespParser
|
// DeepSeekParser: deepseek implementation of RespParser
|
||||||
type deepSeekParser struct {
|
type deepSeekParser struct {
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDeepSeekParser(log *slog.Logger) *deepSeekParser {
|
func NewDeepSeekParser(log *slog.Logger, cfg *config.Config) *deepSeekParser {
|
||||||
return &deepSeekParser{log: log}
|
return &deepSeekParser{log: log, cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
||||||
@@ -36,10 +38,10 @@ func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
|||||||
err := errors.New("empty choices in dsResp")
|
err := errors.New("empty choices in dsResp")
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the response contains tool calls
|
// Check if the response contains tool calls
|
||||||
choice := dsResp.Choices[0]
|
choice := dsResp.Choices[0]
|
||||||
|
|
||||||
// Handle response with message field (OpenAI format)
|
// Handle response with message field (OpenAI format)
|
||||||
if choice.Message.Role != "" {
|
if choice.Message.Role != "" {
|
||||||
if len(choice.Message.ToolCalls) > 0 {
|
if len(choice.Message.ToolCalls) > 0 {
|
||||||
@@ -51,7 +53,7 @@ func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
|||||||
// Regular text response
|
// Regular text response
|
||||||
return choice.Message.Content, nil
|
return choice.Message.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle response with text field (legacy format)
|
// Handle response with text field (legacy format)
|
||||||
return choice.Text, nil
|
return choice.Text, nil
|
||||||
}
|
}
|
||||||
@@ -101,10 +103,11 @@ func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
|||||||
// llama.cpp implementation of RespParser
|
// llama.cpp implementation of RespParser
|
||||||
type lcpRespParser struct {
|
type lcpRespParser struct {
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLCPRespParser(log *slog.Logger) *lcpRespParser {
|
func NewLCPRespParser(log *slog.Logger, cfg *config.Config) *lcpRespParser {
|
||||||
return &lcpRespParser{log: log}
|
return &lcpRespParser{log: log, cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
|
func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
|
||||||
@@ -150,17 +153,19 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type openRouterParser struct {
|
type openRouterParser struct {
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
modelIndex uint32
|
cfg *config.Config
|
||||||
useChatAPI bool
|
modelIndex uint32
|
||||||
|
useChatAPI bool
|
||||||
supportsTools bool
|
supportsTools bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
|
func NewOpenRouterParser(log *slog.Logger, cfg *config.Config) *openRouterParser {
|
||||||
return &openRouterParser{
|
return &openRouterParser{
|
||||||
log: log,
|
log: log,
|
||||||
modelIndex: 0,
|
cfg: cfg,
|
||||||
useChatAPI: false, // Default to completion API which is more widely supported
|
modelIndex: 0,
|
||||||
|
useChatAPI: false, // Default to completion API which is more widely supported
|
||||||
supportsTools: false, // Don't assume tool support
|
supportsTools: false, // Don't assume tool support
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -178,9 +183,9 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
|||||||
err := errors.New("empty choices in openrouter chat response")
|
err := errors.New("empty choices in openrouter chat response")
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
choice := resp.Choices[0]
|
choice := resp.Choices[0]
|
||||||
|
|
||||||
// Check if the response contains tool calls
|
// Check if the response contains tool calls
|
||||||
if len(choice.Message.ToolCalls) > 0 {
|
if len(choice.Message.ToolCalls) > 0 {
|
||||||
// Handle tool call response
|
// Handle tool call response
|
||||||
@@ -188,11 +193,11 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
|||||||
// Return a special marker indicating tool usage
|
// Return a special marker indicating tool usage
|
||||||
return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
|
return fmt.Sprintf("[TOOL_CALL:%s]", toolCall.Function.Name), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Regular text response
|
// Regular text response
|
||||||
return choice.Message.Content, nil
|
return choice.Message.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If using completion API, parse as text completion response (no tool calls)
|
// If using completion API, parse as text completion response (no tool calls)
|
||||||
resp := models.ORCompletionResp{}
|
resp := models.ORCompletionResp{}
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
@@ -204,7 +209,7 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
|||||||
err := errors.New("empty choices in openrouter completion response")
|
err := errors.New("empty choices in openrouter completion response")
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the text content
|
// Return the text content
|
||||||
return resp.Choices[0].Text, nil
|
return resp.Choices[0].Text, nil
|
||||||
}
|
}
|
||||||
@@ -213,10 +218,10 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
|||||||
if p.useChatAPI {
|
if p.useChatAPI {
|
||||||
// Use chat completions API with messages format (supports tool calls)
|
// Use chat completions API with messages format (supports tool calls)
|
||||||
payload := struct {
|
payload := struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []models.RoleMsg `json:"messages"`
|
Messages []models.RoleMsg `json:"messages"`
|
||||||
}{
|
}{
|
||||||
Model: "openai/gpt-4o-mini",
|
Model: p.cfg.ModelName,
|
||||||
Messages: []models.RoleMsg{
|
Messages: []models.RoleMsg{
|
||||||
{Role: "user", Content: prompt},
|
{Role: "user", Content: prompt},
|
||||||
},
|
},
|
||||||
@@ -230,13 +235,13 @@ func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
|||||||
p.log.Debug("made openrouter chat payload", "payload", string(b))
|
p.log.Debug("made openrouter chat payload", "payload", string(b))
|
||||||
return bytes.NewReader(b)
|
return bytes.NewReader(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use completions API with prompt format (no tool calls)
|
// Use completions API with prompt format (no tool calls)
|
||||||
payload := struct {
|
payload := struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
}{
|
}{
|
||||||
Model: "openai/gpt-4o-mini",
|
Model: p.cfg.ModelName,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"grailbench/config"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -94,8 +95,8 @@ func TestParserToolDetection(t *testing.T) {
|
|||||||
|
|
||||||
// Test DeepSeekParser
|
// Test DeepSeekParser
|
||||||
t.Run("DeepSeekParser", func(t *testing.T) {
|
t.Run("DeepSeekParser", func(t *testing.T) {
|
||||||
parser := NewDeepSeekParser(logger)
|
parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
result, err := parser.ParseBytes([]byte(tc.responseJSON))
|
result, err := parser.ParseBytes([]byte(tc.responseJSON))
|
||||||
@@ -112,8 +113,8 @@ func TestParserToolDetection(t *testing.T) {
|
|||||||
|
|
||||||
// Test OpenRouterParser
|
// Test OpenRouterParser
|
||||||
t.Run("OpenRouterParser", func(t *testing.T) {
|
t.Run("OpenRouterParser", func(t *testing.T) {
|
||||||
parser := NewOpenRouterParser(logger)
|
parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
result, err := parser.ParseBytes([]byte(tc.responseJSON))
|
result, err := parser.ParseBytes([]byte(tc.responseJSON))
|
||||||
@@ -127,4 +128,5 @@ func TestParserToolDetection(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"grailbench/config"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -11,7 +12,7 @@ func TestDeepSeekParser_ParseBytes(t *testing.T) {
|
|||||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
// Create a new deepSeekParser
|
// Create a new deepSeekParser
|
||||||
parser := NewDeepSeekParser(logger)
|
parser := NewDeepSeekParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||||
|
|
||||||
// Test case 1: Regular text response
|
// Test case 1: Regular text response
|
||||||
t.Run("RegularTextResponse", func(t *testing.T) {
|
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||||
@@ -119,7 +120,7 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) {
|
|||||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
// Create a new openRouterParser
|
// Create a new openRouterParser
|
||||||
parser := NewOpenRouterParser(logger)
|
parser := NewOpenRouterParser(logger, config.LoadConfigOrDefault("config.toml"))
|
||||||
|
|
||||||
// Test case 1: Regular text response
|
// Test case 1: Regular text response
|
||||||
t.Run("RegularTextResponse", func(t *testing.T) {
|
t.Run("RegularTextResponse", func(t *testing.T) {
|
||||||
@@ -189,4 +190,5 @@ func TestOpenRouterParser_ParseBytes(t *testing.T) {
|
|||||||
t.Errorf("Expected %s, got %s", expected, result)
|
t.Errorf("Expected %s, got %s", expected, result)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user