From e499a1ae370c6cebf6e93ca9161ffce14de99af8 Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Sat, 9 Aug 2025 11:46:33 +0300 Subject: [PATCH] Feat: rp mode --- config/config.go | 18 ++-- main.go | 234 ++++++++++++++++++++++++++++++++++------------- models/models.go | 18 ++++ parser.go | 24 ++--- 4 files changed, 210 insertions(+), 84 deletions(-) diff --git a/config/config.go b/config/config.go index 2887834..24daa47 100644 --- a/config/config.go +++ b/config/config.go @@ -7,10 +7,13 @@ import ( ) type Config struct { - CurrentAPI string `toml:"CurrentAPI"` - APIToken string `toml:"APIToken"` - QuestionsPath string `toml:"QuestionsPath"` - OutPath string `toml:"OutPath"` + RemoteAPI string `toml:"RemoteAPI"` + LocalAPI string `toml:"LocalAPI"` + APIToken string `toml:"APIToken"` + QuestionsPath string `toml:"QuestionsPath"` + RPScenariosDir string `toml:"RPScenariosDir"` + OutDir string `toml:"OutDir"` + MaxTurns int `toml:"MaxTurns"` } func LoadConfigOrDefault(fn string) *Config { @@ -21,9 +24,12 @@ func LoadConfigOrDefault(fn string) *Config { _, err := toml.DecodeFile(fn, &config) if err != nil { fmt.Println("failed to read config from file, loading default", "error", err) - config.CurrentAPI = "http://localhost:8080/completion" + config.RemoteAPI = "https://openrouter.ai/api/v1/completions" + config.LocalAPI = "http://localhost:8080/completion" config.QuestionsPath = "data/questions.json" - config.OutPath = "data/out.json" + config.RPScenariosDir = "scenarios" + config.OutDir = "results" + config.MaxTurns = 10 } return config } diff --git a/main.go b/main.go index ce658e6..f6ba56d 100644 --- a/main.go +++ b/main.go @@ -1,13 +1,17 @@ package main import ( + "bytes" "encoding/json" "errors" + "flag" "fmt" "io" "log/slog" "net/http" "os" + "path/filepath" + "strings" "time" "grailbench/config" @@ -15,10 +19,9 @@ import ( ) var ( - logger *slog.Logger - cfg *config.Config - httpClient = &http.Client{} - currentModel = "" + logger *slog.Logger + cfg *config.Config + httpClient = &http.Client{} ) func loadQuestions(fp string) ([]models.Question, error) { @@ -35,7 +38,35 @@ func loadQuestions(fp string) ([]models.Question, error) { return resp, nil } -func chooseParser() RespParser { +func loadRPScenarios(dir string) ([]*models.RPScenario, error) { + scenarios := []*models.RPScenario{} + files, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + for _, file := range files { + if !file.IsDir() && strings.HasSuffix(file.Name(), ".json") { + fp := filepath.Join(dir, file.Name()) + data, err := os.ReadFile(fp) + if err != nil { + logger.Error("failed to read scenario file", "error", err, "fp", fp) + continue // Skip this file + } + scenario := &models.RPScenario{} + if err := json.Unmarshal(data, &scenario); err != nil { + logger.Error("failed to unmarshal scenario file", "error", err, "fp", fp) + continue // Skip this file + } + scenarios = append(scenarios, scenario) + } + } + return scenarios, nil +} + +func chooseParser(apiURL string) RespParser { + if apiURL == cfg.RemoteAPI { + return NewOpenRouterParser(logger) + } return NewLCPRespParser(logger) } @@ -48,109 +79,188 @@ func init() { } func main() { - questions, err := loadQuestions(cfg.QuestionsPath) - if err != nil { - panic(err) + mode := flag.String("mode", "questions", "Benchmark mode: 'questions' or 'rp'") + flag.Parse() + + switch *mode { + case "questions": + questions, err := loadQuestions(cfg.QuestionsPath) + if err != nil { + panic(err) + } + answers, err := runBench(questions) + if err != nil { + panic(err) + } + data, err := json.MarshalIndent(answers, " ", " ") + if err != nil { + panic(err) + } + if err := os.WriteFile(filepath.Join(cfg.OutDir, "results.json"), data, 0666); err != nil { + panic(err) + } + case "rp": + scenarios, err := loadRPScenarios(cfg.RPScenariosDir) + if err != nil { + panic(err) + } + for _, scenario := range scenarios { + logger.Info("running scenario", "name", scenario.Name) + conversation, err := runRPBench(scenario) + if err != nil { + logger.Error("failed to run rp bench", "error", err, "scenario", scenario.Name) + continue + } + if err := saveRPResult(scenario.Name, conversation); err != nil { + logger.Error("failed to save rp result", "error", err, "scenario", scenario.Name) + } + } + default: + fmt.Println("Unknown mode, please use 'questions' or 'rp'") } - answers, err := runBench(questions) - if err != nil { - panic(err) +} + +func saveRPResult(name string, conversation []models.RPMessage) error { + date := time.Now().Format("2006-01-02") + fn := fmt.Sprintf("%s_%s.txt", name, date) + fp := filepath.Join(cfg.OutDir, fn) + + var b strings.Builder + for _, msg := range conversation { + b.WriteString(fmt.Sprintf("%s:\n%s\n\n", msg.Author, msg.Content)) } - data, err := json.MarshalIndent(answers, " ", " ") - if err != nil { - panic(err) + + return os.WriteFile(fp, []byte(b.String()), 0666) +} + +func runRPBench(scenario *models.RPScenario) ([]models.RPMessage, error) { + conversation := []models.RPMessage{} + if scenario.FirstMsg != "" { + conversation = append(conversation, models.RPMessage{Author: "Scenario", Content: scenario.FirstMsg}) } - if err := os.WriteFile(cfg.OutPath, data, 0666); err != nil { - panic(err) + + for i := 0; i < cfg.MaxTurns; i++ { + // User's turn (local model) + userPrompt := buildRPPrompt(scenario.UserSysPrompt, conversation, scenario.UserName) + userResp, err := callLLM(userPrompt, cfg.LocalAPI) + if err != nil { + return nil, fmt.Errorf("user turn failed: %w", err) + } + parser := chooseParser(cfg.LocalAPI) + userText, err := parser.ParseBytes(userResp) + if err != nil { + return nil, fmt.Errorf("user turn failed parsing response: %w", err) + } + conversation = append(conversation, models.RPMessage{Author: scenario.UserName, Content: userText}) + + // Character's turn (remote model) + charPrompt := buildRPPrompt(scenario.CharSysPrompt, conversation, scenario.CharName) + charResp, err := callLLM(charPrompt, cfg.RemoteAPI) + if err != nil { + return nil, fmt.Errorf("char turn failed: %w", err) + } + parser = chooseParser(cfg.RemoteAPI) + charText, err := parser.ParseBytes(charResp) + if err != nil { + return nil, fmt.Errorf("char turn failed parsing response: %w", err) + } + conversation = append(conversation, models.RPMessage{Author: scenario.CharName, Content: charText}) } + return conversation, nil +} + +func buildRPPrompt(sysPrompt string, history []models.RPMessage, characterName string) string { + var b strings.Builder + b.WriteString(sysPrompt) + b.WriteString("\n\n") + for _, msg := range history { + b.WriteString(fmt.Sprintf("%s: %s\n", msg.Author, msg.Content)) + } + b.WriteString(fmt.Sprintf("%s:", characterName)) + return b.String() } func runBench(questions []models.Question) ([]models.Answer, error) { answers := []models.Answer{} for _, q := range questions { - resp, err := callLLM(buildPrompt(q.Question)) + resp, err := callLLM(buildPrompt(q.Question), cfg.RemoteAPI) // Assuming remote for questions if err != nil { - // log - panic(err) + logger.Error("failed to call llm", "error", err) + continue } - parser := chooseParser() + parser := chooseParser(cfg.RemoteAPI) respText, err := parser.ParseBytes(resp) if err != nil { - panic(err) + logger.Error("failed to parse llm response", "error", err) + continue } - a := models.Answer{Q: q, Answer: respText, Model: currentModel} + a := models.Answer{Q: q, Answer: respText, Model: "default"} // model name from resp? answers = append(answers, a) } return answers, nil } -// openai vs completion func buildPrompt(q string) string { - // sure injection? - // completion return fmt.Sprintf(`Q:\n%s\nA:\nSure,`, q) } -func callLLM(prompt string) ([]byte, error) { +func callLLM(prompt string, apiURL string) ([]byte, error) { method := "POST" - // Generate the payload once as bytes - parser := chooseParser() + parser := chooseParser(apiURL) payloadReader := parser.MakePayload(prompt) + + var payloadBytes []byte + if payloadReader != nil { + var err error + payloadBytes, err = io.ReadAll(payloadReader) + if err != nil { + return nil, fmt.Errorf("failed to read payload: %w", err) + } + } + client := &http.Client{} maxRetries := 6 - baseDelay := 2 // seconds + baseDelay := 2 * time.Second + for attempt := 0; attempt < maxRetries; attempt++ { - // Create a new request for the attempt - req, err := http.NewRequest(method, cfg.CurrentAPI, payloadReader) + req, err := http.NewRequest(method, apiURL, bytes.NewReader(payloadBytes)) if err != nil { - if attempt == maxRetries-1 { - return nil, fmt.Errorf("LLM call failed after %d retries on request creation: %w", maxRetries, err) - } - logger.Error("failed to make new request; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt) - time.Sleep(time.Duration(baseDelay) * time.Second * time.Duration(attempt+1)) - continue + return nil, fmt.Errorf("failed to create request: %w", err) } + req.Header.Add("Content-Type", "application/json") req.Header.Add("Accept", "application/json") - req.Header.Add("Authorization", "Bearer "+cfg.APIToken) + if apiURL == cfg.RemoteAPI { + req.Header.Add("Authorization", "Bearer "+cfg.APIToken) + } + resp, err := client.Do(req) if err != nil { if attempt == maxRetries-1 { return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err) } - logger.Error("http request failed; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt) - delay := time.Duration(baseDelay*(attempt+1)) * time.Second - time.Sleep(delay) + logger.Error("http request failed; will retry", "error", err, "url", apiURL, "attempt", attempt) + time.Sleep(baseDelay * time.Duration(1<= 400 { if attempt == maxRetries-1 { - return nil, fmt.Errorf("LLM call failed after %d retries on reading body: %w", maxRetries, err) + return nil, fmt.Errorf("LLM call failed after %d retries, got status %d, body: %s", maxRetries, resp.StatusCode, string(body)) } - logger.Error("failed to read response body; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt) - delay := time.Duration(baseDelay*(attempt+1)) * time.Second - time.Sleep(delay) + logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt, "body", string(body)) + time.Sleep(baseDelay * time.Duration(1<= 400 && resp.StatusCode < 600 { - if attempt == maxRetries-1 { - return nil, fmt.Errorf("LLM call failed after %d retries, got status %d", maxRetries, resp.StatusCode) - } - logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt) - delay := time.Duration((baseDelay * (1 << attempt))) * time.Second - time.Sleep(delay) - continue - } - if resp.StatusCode != http.StatusOK { - // For non-retriable errors, return immediately - return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body)) - } - // Success - logger.Debug("llm resp", "body", string(body), "url", cfg.CurrentAPI, "attempt", attempt) + + logger.Debug("llm resp", "body", string(body), "url", apiURL, "attempt", attempt) return body, nil } - return nil, errors.New("unknown error in retry loop") + return nil, errors.New("exceeded max retries") } diff --git a/models/models.go b/models/models.go index 0402400..80cbe2c 100644 --- a/models/models.go +++ b/models/models.go @@ -67,3 +67,21 @@ type LLMResp struct { StoppingWord string `json:"stopping_word"` TokensCached int `json:"tokens_cached"` } + +type LCPResp struct { + Content string `json:"content"` +} + +type RPScenario struct { + Name string `json:"name"` + CharSysPrompt string `json:"char_sys_prompt"` + UserSysPrompt string `json:"user_sys_prompt"` + CharName string `json:"char_name"` + UserName string `json:"user_name"` + FirstMsg string `json:"first_msg"` +} + +type RPMessage struct { + Author string `json:"author"` + Content string `json:"content"` +} diff --git a/parser.go b/parser.go index 84e6700..e726e5a 100644 --- a/parser.go +++ b/parser.go @@ -105,7 +105,7 @@ func NewOpenRouterParser(log *slog.Logger) *openRouterParser { func (p *openRouterParser) ParseBytes(body []byte) (string, error) { // parsing logic here - resp := models.OpenRouterResp{} + resp := models.DSResp{} if err := json.Unmarshal(body, &resp); err != nil { p.log.Error("failed to unmarshal", "error", err) return "", err @@ -115,33 +115,25 @@ func (p *openRouterParser) ParseBytes(body []byte) (string, error) { err := errors.New("empty choices in resp") return "", err } - text := resp.Choices[0].Message.Content + text := resp.Choices[0].Text return text, nil } func (p *openRouterParser) MakePayload(prompt string) io.Reader { // Models to rotate through models := []string{ - "google/gemini-2.0-flash-exp:free", - "deepseek/deepseek-chat-v3-0324:free", - "mistralai/mistral-small-3.2-24b-instruct:free", - "qwen/qwen3-14b:free", - "deepseek/deepseek-r1:free", - "google/gemma-3-27b-it:free", - "meta-llama/llama-3.3-70b-instruct:free", + "google/gemini-flash-1.5", + "deepseek/deepseek-coder", + "mistralai/mistral-7b-instruct", + "qwen/qwen-72b-chat", + "meta-llama/llama-3-8b-instruct", } // Get next model index using atomic addition for thread safety p.modelIndex++ model := models[int(p.modelIndex)%len(models)] strPayload := fmt.Sprintf(`{ "model": "%s", - "max_tokens": 300, - "messages": [ - { - "role": "user", - "content": "%s" - } - ] + "prompt": "%s" }`, model, prompt) p.log.Debug("made openrouter payload", "model", model, "payload", strPayload) return strings.NewReader(strPayload)