Files
grailbench/main.go
Grail Finder 57b808cb31 Feat: tools
2025-08-27 10:03:09 +03:00

266 lines
7.4 KiB
Go

package main
import (
"bytes"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"grailbench/config"
"grailbench/models"
)
// Package-level state shared by the functions in this file.
var (
// logger is the structured logger used throughout; it is assigned in init.
logger *slog.Logger
// cfg holds the loaded configuration; it is assigned in init.
cfg *config.Config
// httpClient is the single client reused for every request made by callLLM.
// NOTE(review): no Timeout is set, so the client itself places no upper
// bound on how long one request may take.
httpClient = &http.Client{}
)
// loadQuestions reads the JSON file at fp and decodes it into a slice of
// questions. A read or decode failure is logged and then returned unchanged
// together with a nil slice.
func loadQuestions(fp string) ([]models.Question, error) {
	raw, readErr := os.ReadFile(fp)
	if readErr != nil {
		logger.Error("failed to read file", "error", readErr, "fp", fp)
		return nil, readErr
	}

	questions := []models.Question{}
	decodeErr := json.Unmarshal(raw, &questions)
	if decodeErr != nil {
		logger.Error("failed to unmarshal file", "error", decodeErr, "fp", fp)
		return nil, decodeErr
	}
	return questions, nil
}
// loadRPScenarios decodes every regular "*.json" entry directly inside dir
// into an RPScenario and returns the scenarios that loaded successfully.
//
// A file that cannot be read, cannot be decoded, or contains only JSON null
// is logged and skipped, so one bad file does not abort the whole run. The
// only error returned is a failure to list dir itself. Every element of the
// returned slice is non-nil.
func loadRPScenarios(dir string) ([]*models.RPScenario, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, fmt.Errorf("reading scenarios dir %q: %w", dir, err)
	}

	scenarios := []*models.RPScenario{}
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
			continue
		}
		fp := filepath.Join(dir, entry.Name())
		data, err := os.ReadFile(fp)
		if err != nil {
			logger.Error("failed to read scenario file", "error", err, "fp", fp)
			continue // Skip this file
		}
		// Decode into a nil pointer: Unmarshal allocates the struct for a JSON
		// object and leaves the pointer nil for a JSON null, which lets us
		// detect and skip such files instead of handing a nil scenario to
		// callers that dereference it.
		var scenario *models.RPScenario
		if err := json.Unmarshal(data, &scenario); err != nil {
			logger.Error("failed to unmarshal scenario file", "error", err, "fp", fp)
			continue // Skip this file
		}
		if scenario == nil {
			logger.Error("scenario file contains JSON null; skipping", "fp", fp)
			continue
		}
		scenarios = append(scenarios, scenario)
	}
	return scenarios, nil
}
// chooseParser returns the response parser for apiURL: the OpenRouter parser
// when apiURL equals the configured remote API, the LCP parser for any other
// URL. A new parser is constructed on every call.
func chooseParser(apiURL string) RespParser {
	switch apiURL {
	case cfg.RemoteAPI:
		return NewOpenRouterParser(logger)
	default:
		return NewLCPRespParser(logger)
	}
}
// init sets up the package-level logger (JSON records on stderr, debug level,
// with source locations) and loads the configuration from "config.toml".
func init() {
	opts := &slog.HandlerOptions{
		AddSource: true,
		Level:     slog.LevelDebug,
	}
	handler := slog.NewJSONHandler(os.Stderr, opts)
	logger = slog.New(handler)

	cfg = config.LoadConfigOrDefault("config.toml")
}
func main() {
mode := flag.String("mode", "questions", "Benchmark mode: 'questions' or 'rp'")
flag.Parse()
switch *mode {
case "questions":
questions, err := loadQuestions(cfg.QuestionsPath)
if err != nil {
panic(err)
}
answers, err := runBench(questions)
if err != nil {
panic(err)
}
data, err := json.MarshalIndent(answers, " ", " ")
if err != nil {
panic(err)
}
if err := os.WriteFile(filepath.Join(cfg.OutDir, "results.json"), data, 0666); err != nil {
panic(err)
}
case "rp":
scenarios, err := loadRPScenarios(cfg.RPScenariosDir)
if err != nil {
panic(err)
}
for _, scenario := range scenarios {
logger.Info("running scenario", "name", scenario.Name)
conversation, err := runRPBench(scenario)
if err != nil {
logger.Error("failed to run rp bench", "error", err, "scenario", scenario.Name)
continue
}
if err := saveRPResult(scenario.Name, conversation); err != nil {
logger.Error("failed to save rp result", "error", err, "scenario", scenario.Name)
}
}
default:
fmt.Println("Unknown mode, please use 'questions' or 'rp'")
}
}
// saveRPResult writes conversation as plain text to "<name>_<YYYY-MM-DD>.txt"
// inside cfg.OutDir, using today's date. Each message is rendered as the
// author followed by a colon on one line, the content on the next, then a
// blank line. The error from writing the file is returned as-is.
func saveRPResult(name string, conversation []models.RPMessage) error {
	stamp := time.Now().Format("2006-01-02")
	outPath := filepath.Join(cfg.OutDir, name+"_"+stamp+".txt")

	var transcript strings.Builder
	for i := range conversation {
		fmt.Fprintf(&transcript, "%s:\n%s\n\n", conversation[i].Author, conversation[i].Content)
	}
	return os.WriteFile(outPath, []byte(transcript.String()), 0666)
}
// runRPBench plays out a role-play scenario for cfg.MaxTurns rounds and
// returns the resulting conversation.
//
// If the scenario has a FirstMsg it opens the conversation under the author
// "Scenario". Each round then consists of a user turn answered by
// cfg.LocalAPI followed by a character turn answered by cfg.RemoteAPI; every
// reply is appended to the conversation before the next prompt is built. The
// first failing call or parse aborts the run and returns a nil conversation
// with the wrapped error.
func runRPBench(scenario *models.RPScenario) ([]models.RPMessage, error) {
	conversation := []models.RPMessage{}
	if scenario.FirstMsg != "" {
		opening := models.RPMessage{Author: "Scenario", Content: scenario.FirstMsg}
		conversation = append(conversation, opening)
	}

	// speak performs one turn: it prompts apiURL with the history so far,
	// parses the reply and appends it to the conversation under speaker.
	// role only labels the error messages.
	speak := func(role, sysPrompt, speaker, apiURL string) error {
		raw, err := callLLM(buildRPPrompt(sysPrompt, conversation, speaker), apiURL)
		if err != nil {
			return fmt.Errorf("%s turn failed: %w", role, err)
		}
		text, err := chooseParser(apiURL).ParseBytes(raw)
		if err != nil {
			return fmt.Errorf("%s turn failed parsing response: %w", role, err)
		}
		conversation = append(conversation, models.RPMessage{Author: speaker, Content: text})
		return nil
	}

	for turn := 0; turn < cfg.MaxTurns; turn++ {
		// User's turn (local model)
		if err := speak("user", scenario.UserSysPrompt, scenario.UserName, cfg.LocalAPI); err != nil {
			return nil, err
		}
		// Character's turn (remote model)
		if err := speak("char", scenario.CharSysPrompt, scenario.CharName, cfg.RemoteAPI); err != nil {
			return nil, err
		}
	}
	return conversation, nil
}
// buildRPPrompt assembles the text prompt for one role-play turn: the system
// prompt, a blank line, one "Author: Content" line per message in history,
// and finally "characterName:" left open for the model to continue.
func buildRPPrompt(sysPrompt string, history []models.RPMessage, characterName string) string {
	var prompt strings.Builder
	prompt.WriteString(sysPrompt + "\n\n")
	for i := range history {
		fmt.Fprintf(&prompt, "%s: %s\n", history[i].Author, history[i].Content)
	}
	prompt.WriteString(characterName + ":")
	return prompt.String()
}
// runBench sends every question to cfg.RemoteAPI and collects the parsed
// replies as answers, each recorded with the model name "default".
//
// A question whose call or response parsing fails is logged and left out of
// the result. The returned error is always nil, and the slice is non-nil
// even when no question succeeded.
func runBench(questions []models.Question) ([]models.Answer, error) {
	answers := []models.Answer{}
	for idx := range questions {
		question := questions[idx]

		raw, callErr := callLLM(buildPrompt(question.Question), cfg.RemoteAPI) // Assuming remote for questions
		if callErr != nil {
			logger.Error("failed to call llm", "error", callErr)
			continue
		}

		text, parseErr := chooseParser(cfg.RemoteAPI).ParseBytes(raw)
		if parseErr != nil {
			logger.Error("failed to parse llm response", "error", parseErr)
			continue
		}

		// model name from resp?
		answers = append(answers, models.Answer{Q: question, Answer: text, Model: "default"})
	}
	return answers, nil
}
// buildPrompt wraps the question q in a "Q:/A:" completion template and
// pre-seeds the answer with "Sure," for the model to continue.
//
// The template is an interpreted string literal so that each \n is a real
// newline; inside a raw (backtick) literal it would reach the model as a
// literal backslash followed by 'n'.
func buildPrompt(q string) string {
	return fmt.Sprintf("Q:\n%s\nA:\nSure,", q)
}
// callLLM POSTs prompt to apiURL and returns the raw response body.
//
// The request payload is produced by the parser chosen for apiURL and is
// buffered once so that every retry resends it from the start. When apiURL is
// the configured remote API the request carries a bearer token from
// cfg.APIToken.
//
// Transport errors and retriable statuses (408, 429 and any 5xx) are retried
// up to maxRetries attempts in total, sleeping baseDelay * 2^attempt between
// attempts. Any other status >= 400 is a permanent client-side failure (bad
// request, bad credentials, unknown route, ...) that a retry cannot fix, so
// it is returned immediately instead of after the full backoff schedule. A
// failure to read the response body is also returned immediately.
func callLLM(prompt string, apiURL string) ([]byte, error) {
	const (
		maxRetries = 6
		baseDelay  = 2 * time.Second
	)

	parser := chooseParser(apiURL)
	var payloadBytes []byte
	if payloadReader := parser.MakePayload(prompt); payloadReader != nil {
		var err error
		payloadBytes, err = io.ReadAll(payloadReader)
		if err != nil {
			return nil, fmt.Errorf("failed to read payload: %w", err)
		}
	}

	for attempt := 0; attempt < maxRetries; attempt++ {
		lastAttempt := attempt == maxRetries-1
		backoff := baseDelay * time.Duration(1<<attempt)

		req, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewReader(payloadBytes))
		if err != nil {
			return nil, fmt.Errorf("failed to create request: %w", err)
		}
		req.Header.Add("Content-Type", "application/json")
		req.Header.Add("Accept", "application/json")
		if apiURL == cfg.RemoteAPI {
			req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
		}

		resp, err := httpClient.Do(req)
		if err != nil {
			if lastAttempt {
				return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
			}
			logger.Error("http request failed; will retry", "error", err, "url", apiURL, "attempt", attempt)
			time.Sleep(backoff)
			continue
		}

		// Read the body fully and close it before deciding what to do, so the
		// connection is released on every path, including retries.
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("failed to read response body: %w", err)
		}

		if resp.StatusCode >= 400 {
			retriable := resp.StatusCode >= 500 ||
				resp.StatusCode == http.StatusTooManyRequests ||
				resp.StatusCode == http.StatusRequestTimeout
			if !retriable {
				return nil, fmt.Errorf("LLM call failed with non-retriable status %d, body: %s", resp.StatusCode, string(body))
			}
			if lastAttempt {
				return nil, fmt.Errorf("LLM call failed after %d retries, got status %d, body: %s", maxRetries, resp.StatusCode, string(body))
			}
			logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt, "body", string(body))
			time.Sleep(backoff)
			continue
		}

		logger.Debug("llm resp", "body", string(body), "url", apiURL, "attempt", attempt)
		return body, nil
	}
	return nil, errors.New("exceeded max retries")
}