Feat: rp mode
This commit is contained in:
@@ -7,10 +7,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Config holds the benchmark configuration loaded from a TOML file.
type Config struct {
	// RemoteAPI is the completions endpoint of the remote provider (OpenRouter).
	RemoteAPI string `toml:"RemoteAPI"`
	// LocalAPI is the completion endpoint of the locally hosted model.
	LocalAPI string `toml:"LocalAPI"`
	// APIToken is the bearer token sent only to the remote API.
	APIToken string `toml:"APIToken"`
	// QuestionsPath is the JSON file with benchmark questions.
	QuestionsPath string `toml:"QuestionsPath"`
	// RPScenariosDir is the directory holding *.json role-play scenarios.
	RPScenariosDir string `toml:"RPScenariosDir"`
	// OutDir is where benchmark results are written.
	OutDir string `toml:"OutDir"`
	// MaxTurns caps the number of user/character exchange rounds per scenario.
	MaxTurns int `toml:"MaxTurns"`
}
func LoadConfigOrDefault(fn string) *Config {
|
func LoadConfigOrDefault(fn string) *Config {
|
||||||
@@ -21,9 +24,12 @@ func LoadConfigOrDefault(fn string) *Config {
|
|||||||
_, err := toml.DecodeFile(fn, &config)
|
_, err := toml.DecodeFile(fn, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("failed to read config from file, loading default", "error", err)
|
fmt.Println("failed to read config from file, loading default", "error", err)
|
||||||
config.CurrentAPI = "http://localhost:8080/completion"
|
config.RemoteAPI = "https://openrouter.ai/api/v1/completions"
|
||||||
|
config.LocalAPI = "http://localhost:8080/completion"
|
||||||
config.QuestionsPath = "data/questions.json"
|
config.QuestionsPath = "data/questions.json"
|
||||||
config.OutPath = "data/out.json"
|
config.RPScenariosDir = "scenarios"
|
||||||
|
config.OutDir = "results"
|
||||||
|
config.MaxTurns = 10
|
||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
234
main.go
234
main.go
@@ -1,13 +1,17 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"grailbench/config"
|
"grailbench/config"
|
||||||
@@ -15,10 +19,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
httpClient = &http.Client{}
|
httpClient = &http.Client{}
|
||||||
currentModel = ""
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadQuestions(fp string) ([]models.Question, error) {
|
func loadQuestions(fp string) ([]models.Question, error) {
|
||||||
@@ -35,7 +38,35 @@ func loadQuestions(fp string) ([]models.Question, error) {
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func chooseParser() RespParser {
|
func loadRPScenarios(dir string) ([]*models.RPScenario, error) {
|
||||||
|
scenarios := []*models.RPScenario{}
|
||||||
|
files, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
if !file.IsDir() && strings.HasSuffix(file.Name(), ".json") {
|
||||||
|
fp := filepath.Join(dir, file.Name())
|
||||||
|
data, err := os.ReadFile(fp)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to read scenario file", "error", err, "fp", fp)
|
||||||
|
continue // Skip this file
|
||||||
|
}
|
||||||
|
scenario := &models.RPScenario{}
|
||||||
|
if err := json.Unmarshal(data, &scenario); err != nil {
|
||||||
|
logger.Error("failed to unmarshal scenario file", "error", err, "fp", fp)
|
||||||
|
continue // Skip this file
|
||||||
|
}
|
||||||
|
scenarios = append(scenarios, scenario)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scenarios, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func chooseParser(apiURL string) RespParser {
|
||||||
|
if apiURL == cfg.RemoteAPI {
|
||||||
|
return NewOpenRouterParser(logger)
|
||||||
|
}
|
||||||
return NewLCPRespParser(logger)
|
return NewLCPRespParser(logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,109 +79,188 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
questions, err := loadQuestions(cfg.QuestionsPath)
|
mode := flag.String("mode", "questions", "Benchmark mode: 'questions' or 'rp'")
|
||||||
if err != nil {
|
flag.Parse()
|
||||||
panic(err)
|
|
||||||
|
switch *mode {
|
||||||
|
case "questions":
|
||||||
|
questions, err := loadQuestions(cfg.QuestionsPath)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
answers, err := runBench(questions)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
data, err := json.MarshalIndent(answers, " ", " ")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(cfg.OutDir, "results.json"), data, 0666); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
case "rp":
|
||||||
|
scenarios, err := loadRPScenarios(cfg.RPScenariosDir)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
for _, scenario := range scenarios {
|
||||||
|
logger.Info("running scenario", "name", scenario.Name)
|
||||||
|
conversation, err := runRPBench(scenario)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to run rp bench", "error", err, "scenario", scenario.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := saveRPResult(scenario.Name, conversation); err != nil {
|
||||||
|
logger.Error("failed to save rp result", "error", err, "scenario", scenario.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
fmt.Println("Unknown mode, please use 'questions' or 'rp'")
|
||||||
}
|
}
|
||||||
answers, err := runBench(questions)
|
}
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
func saveRPResult(name string, conversation []models.RPMessage) error {
|
||||||
|
date := time.Now().Format("2006-01-02")
|
||||||
|
fn := fmt.Sprintf("%s_%s.txt", name, date)
|
||||||
|
fp := filepath.Join(cfg.OutDir, fn)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
for _, msg := range conversation {
|
||||||
|
b.WriteString(fmt.Sprintf("%s:\n%s\n\n", msg.Author, msg.Content))
|
||||||
}
|
}
|
||||||
data, err := json.MarshalIndent(answers, " ", " ")
|
|
||||||
if err != nil {
|
return os.WriteFile(fp, []byte(b.String()), 0666)
|
||||||
panic(err)
|
}
|
||||||
|
|
||||||
|
func runRPBench(scenario *models.RPScenario) ([]models.RPMessage, error) {
|
||||||
|
conversation := []models.RPMessage{}
|
||||||
|
if scenario.FirstMsg != "" {
|
||||||
|
conversation = append(conversation, models.RPMessage{Author: "Scenario", Content: scenario.FirstMsg})
|
||||||
}
|
}
|
||||||
if err := os.WriteFile(cfg.OutPath, data, 0666); err != nil {
|
|
||||||
panic(err)
|
for i := 0; i < cfg.MaxTurns; i++ {
|
||||||
|
// User's turn (local model)
|
||||||
|
userPrompt := buildRPPrompt(scenario.UserSysPrompt, conversation, scenario.UserName)
|
||||||
|
userResp, err := callLLM(userPrompt, cfg.LocalAPI)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("user turn failed: %w", err)
|
||||||
|
}
|
||||||
|
parser := chooseParser(cfg.LocalAPI)
|
||||||
|
userText, err := parser.ParseBytes(userResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("user turn failed parsing response: %w", err)
|
||||||
|
}
|
||||||
|
conversation = append(conversation, models.RPMessage{Author: scenario.UserName, Content: userText})
|
||||||
|
|
||||||
|
// Character's turn (remote model)
|
||||||
|
charPrompt := buildRPPrompt(scenario.CharSysPrompt, conversation, scenario.CharName)
|
||||||
|
charResp, err := callLLM(charPrompt, cfg.RemoteAPI)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("char turn failed: %w", err)
|
||||||
|
}
|
||||||
|
parser = chooseParser(cfg.RemoteAPI)
|
||||||
|
charText, err := parser.ParseBytes(charResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("char turn failed parsing response: %w", err)
|
||||||
|
}
|
||||||
|
conversation = append(conversation, models.RPMessage{Author: scenario.CharName, Content: charText})
|
||||||
}
|
}
|
||||||
|
return conversation, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildRPPrompt(sysPrompt string, history []models.RPMessage, characterName string) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(sysPrompt)
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
for _, msg := range history {
|
||||||
|
b.WriteString(fmt.Sprintf("%s: %s\n", msg.Author, msg.Content))
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf("%s:", characterName))
|
||||||
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func runBench(questions []models.Question) ([]models.Answer, error) {
|
func runBench(questions []models.Question) ([]models.Answer, error) {
|
||||||
answers := []models.Answer{}
|
answers := []models.Answer{}
|
||||||
for _, q := range questions {
|
for _, q := range questions {
|
||||||
resp, err := callLLM(buildPrompt(q.Question))
|
resp, err := callLLM(buildPrompt(q.Question), cfg.RemoteAPI) // Assuming remote for questions
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// log
|
logger.Error("failed to call llm", "error", err)
|
||||||
panic(err)
|
continue
|
||||||
}
|
}
|
||||||
parser := chooseParser()
|
parser := chooseParser(cfg.RemoteAPI)
|
||||||
respText, err := parser.ParseBytes(resp)
|
respText, err := parser.ParseBytes(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
logger.Error("failed to parse llm response", "error", err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
a := models.Answer{Q: q, Answer: respText, Model: currentModel}
|
a := models.Answer{Q: q, Answer: respText, Model: "default"} // model name from resp?
|
||||||
answers = append(answers, a)
|
answers = append(answers, a)
|
||||||
}
|
}
|
||||||
return answers, nil
|
return answers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildPrompt wraps a question in a completion-style template that primes
// the model with a leading "Sure," answer token.
//
// NOTE(review): the template is a raw (backquoted) string, so the \n
// sequences are emitted literally as backslash-n, not newlines — confirm
// this is intentional for the completion endpoint.
func buildPrompt(q string) string {
	return fmt.Sprintf(`Q:\n%s\nA:\nSure,`, q)
}
||||||
func callLLM(prompt string) ([]byte, error) {
|
func callLLM(prompt string, apiURL string) ([]byte, error) {
|
||||||
method := "POST"
|
method := "POST"
|
||||||
// Generate the payload once as bytes
|
parser := chooseParser(apiURL)
|
||||||
parser := chooseParser()
|
|
||||||
payloadReader := parser.MakePayload(prompt)
|
payloadReader := parser.MakePayload(prompt)
|
||||||
|
|
||||||
|
var payloadBytes []byte
|
||||||
|
if payloadReader != nil {
|
||||||
|
var err error
|
||||||
|
payloadBytes, err = io.ReadAll(payloadReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read payload: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
maxRetries := 6
|
maxRetries := 6
|
||||||
baseDelay := 2 // seconds
|
baseDelay := 2 * time.Second
|
||||||
|
|
||||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
// Create a new request for the attempt
|
req, err := http.NewRequest(method, apiURL, bytes.NewReader(payloadBytes))
|
||||||
req, err := http.NewRequest(method, cfg.CurrentAPI, payloadReader)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if attempt == maxRetries-1 {
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
return nil, fmt.Errorf("LLM call failed after %d retries on request creation: %w", maxRetries, err)
|
|
||||||
}
|
|
||||||
logger.Error("failed to make new request; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
|
||||||
time.Sleep(time.Duration(baseDelay) * time.Second * time.Duration(attempt+1))
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Add("Accept", "application/json")
|
req.Header.Add("Accept", "application/json")
|
||||||
req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
|
if apiURL == cfg.RemoteAPI {
|
||||||
|
req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if attempt == maxRetries-1 {
|
if attempt == maxRetries-1 {
|
||||||
return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
|
return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
|
||||||
}
|
}
|
||||||
logger.Error("http request failed; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
logger.Error("http request failed; will retry", "error", err, "url", apiURL, "attempt", attempt)
|
||||||
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
|
time.Sleep(baseDelay * time.Duration(1<<attempt))
|
||||||
time.Sleep(delay)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
if attempt == maxRetries-1 {
|
if attempt == maxRetries-1 {
|
||||||
return nil, fmt.Errorf("LLM call failed after %d retries on reading body: %w", maxRetries, err)
|
return nil, fmt.Errorf("LLM call failed after %d retries, got status %d, body: %s", maxRetries, resp.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
logger.Error("failed to read response body; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt, "body", string(body))
|
||||||
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
|
time.Sleep(baseDelay * time.Duration(1<<attempt))
|
||||||
time.Sleep(delay)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Check status code
|
|
||||||
if resp.StatusCode >= 400 && resp.StatusCode < 600 {
|
logger.Debug("llm resp", "body", string(body), "url", apiURL, "attempt", attempt)
|
||||||
if attempt == maxRetries-1 {
|
|
||||||
return nil, fmt.Errorf("LLM call failed after %d retries, got status %d", maxRetries, resp.StatusCode)
|
|
||||||
}
|
|
||||||
logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt)
|
|
||||||
delay := time.Duration((baseDelay * (1 << attempt))) * time.Second
|
|
||||||
time.Sleep(delay)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
// For non-retriable errors, return immediately
|
|
||||||
return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
// Success
|
|
||||||
logger.Debug("llm resp", "body", string(body), "url", cfg.CurrentAPI, "attempt", attempt)
|
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("unknown error in retry loop")
|
return nil, errors.New("exceeded max retries")
|
||||||
}
|
}
|
||||||
|
@@ -67,3 +67,21 @@ type LLMResp struct {
|
|||||||
StoppingWord string `json:"stopping_word"`
|
StoppingWord string `json:"stopping_word"`
|
||||||
TokensCached int `json:"tokens_cached"`
|
TokensCached int `json:"tokens_cached"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LCPResp is the llama.cpp /completion response; only the generated text
// field is decoded.
type LCPResp struct {
	Content string `json:"content"`
}
|
||||||
|
// RPScenario describes one role-play benchmark scenario loaded from JSON:
// the system prompts and display names for both sides of the conversation,
// plus an optional seed message.
type RPScenario struct {
	Name          string `json:"name"`
	CharSysPrompt string `json:"char_sys_prompt"`
	UserSysPrompt string `json:"user_sys_prompt"`
	CharName      string `json:"char_name"`
	UserName      string `json:"user_name"`
	FirstMsg      string `json:"first_msg"`
}
|
|
||||||
|
// RPMessage is a single turn of a role-play conversation.
type RPMessage struct {
	Author  string `json:"author"`
	Content string `json:"content"`
}
|
24
parser.go
24
parser.go
@@ -105,7 +105,7 @@ func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
|
|||||||
|
|
||||||
func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
||||||
// parsing logic here
|
// parsing logic here
|
||||||
resp := models.OpenRouterResp{}
|
resp := models.DSResp{}
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
p.log.Error("failed to unmarshal", "error", err)
|
p.log.Error("failed to unmarshal", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -115,33 +115,25 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
|||||||
err := errors.New("empty choices in resp")
|
err := errors.New("empty choices in resp")
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
text := resp.Choices[0].Message.Content
|
text := resp.Choices[0].Text
|
||||||
return text, nil
|
return text, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
||||||
// Models to rotate through
|
// Models to rotate through
|
||||||
models := []string{
|
models := []string{
|
||||||
"google/gemini-2.0-flash-exp:free",
|
"google/gemini-flash-1.5",
|
||||||
"deepseek/deepseek-chat-v3-0324:free",
|
"deepseek/deepseek-coder",
|
||||||
"mistralai/mistral-small-3.2-24b-instruct:free",
|
"mistralai/mistral-7b-instruct",
|
||||||
"qwen/qwen3-14b:free",
|
"qwen/qwen-72b-chat",
|
||||||
"deepseek/deepseek-r1:free",
|
"meta-llama/llama-3-8b-instruct",
|
||||||
"google/gemma-3-27b-it:free",
|
|
||||||
"meta-llama/llama-3.3-70b-instruct:free",
|
|
||||||
}
|
}
|
||||||
// Get next model index using atomic addition for thread safety
|
// Get next model index using atomic addition for thread safety
|
||||||
p.modelIndex++
|
p.modelIndex++
|
||||||
model := models[int(p.modelIndex)%len(models)]
|
model := models[int(p.modelIndex)%len(models)]
|
||||||
strPayload := fmt.Sprintf(`{
|
strPayload := fmt.Sprintf(`{
|
||||||
"model": "%s",
|
"model": "%s",
|
||||||
"max_tokens": 300,
|
"prompt": "%s"
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "%s"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}`, model, prompt)
|
}`, model, prompt)
|
||||||
p.log.Debug("made openrouter payload", "model", model, "payload", strPayload)
|
p.log.Debug("made openrouter payload", "model", model, "payload", strPayload)
|
||||||
return strings.NewReader(strPayload)
|
return strings.NewReader(strPayload)
|
||||||
|
Reference in New Issue
Block a user