package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"grailbench/config"
	"grailbench/models"
)

var (
	logger     *slog.Logger
	cfg        *config.Config
	httpClient = &http.Client{}
)

// loadQuestions reads the benchmark questions from a JSON file.
func loadQuestions(fp string) ([]models.Question, error) {
	data, err := os.ReadFile(fp)
	if err != nil {
		logger.Error("failed to read file", "error", err, "fp", fp)
		return nil, err
	}
	questions := []models.Question{}
	if err := json.Unmarshal(data, &questions); err != nil {
		logger.Error("failed to unmarshal file", "error", err, "fp", fp)
		return nil, err
	}
	return questions, nil
}

// loadRPScenarios loads every *.json roleplay scenario found in dir.
// Unreadable or malformed files are logged and skipped.
func loadRPScenarios(dir string) ([]*models.RPScenario, error) {
	scenarios := []*models.RPScenario{}
	files, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}
	for _, file := range files {
		if file.IsDir() || !strings.HasSuffix(file.Name(), ".json") {
			continue
		}
		fp := filepath.Join(dir, file.Name())
		data, err := os.ReadFile(fp)
		if err != nil {
			logger.Error("failed to read scenario file", "error", err, "fp", fp)
			continue // Skip this file
		}
		scenario := &models.RPScenario{}
		if err := json.Unmarshal(data, scenario); err != nil {
			logger.Error("failed to unmarshal scenario file", "error", err, "fp", fp)
			continue // Skip this file
		}
		scenarios = append(scenarios, scenario)
	}
	return scenarios, nil
}

// chooseParser picks the response parser matching the API endpoint:
// the OpenRouter parser for the remote API, the LCP parser otherwise.
func chooseParser(apiURL string) RespParser {
	if apiURL == cfg.RemoteAPI {
		return NewOpenRouterParser(logger)
	}
	return NewLCPRespParser(logger)
}

func init() {
	logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
		Level:     slog.LevelDebug,
		AddSource: true,
	}))
	cfg = config.LoadConfigOrDefault("config.toml")
}

func main() {
	mode := flag.String("mode", "questions", "Benchmark mode: 'questions' or 'rp'")
	flag.Parse()

	switch *mode {
	case "questions":
		questions, err := loadQuestions(cfg.QuestionsPath)
		if err != nil {
			panic(err)
		}
		answers, err := runBench(questions)
		if err != nil {
			panic(err)
		}
		data, err := json.MarshalIndent(answers, "", "  ")
		if err != nil {
			panic(err)
		}
		if err := os.WriteFile(filepath.Join(cfg.OutDir, "results.json"), data, 0666); err != nil {
			panic(err)
		}
	case "rp":
		scenarios, err := loadRPScenarios(cfg.RPScenariosDir)
		if err != nil {
			panic(err)
		}
		for _, scenario := range scenarios {
			logger.Info("running scenario", "name", scenario.Name)
			conversation, err := runRPBench(scenario)
			if err != nil {
				logger.Error("failed to run rp bench", "error", err, "scenario", scenario.Name)
				continue
			}
			if err := saveRPResult(scenario.Name, conversation); err != nil {
				logger.Error("failed to save rp result", "error", err, "scenario", scenario.Name)
			}
		}
	default:
		fmt.Println("Unknown mode, please use 'questions' or 'rp'")
	}
}

// saveRPResult writes the finished conversation to OutDir as <name>_<date>.txt,
// one "Author:\ncontent" block per message.
func saveRPResult(name string, conversation []models.RPMessage) error {
	date := time.Now().Format("2006-01-02")
	fn := fmt.Sprintf("%s_%s.txt", name, date)
	fp := filepath.Join(cfg.OutDir, fn)
	var b strings.Builder
	for _, msg := range conversation {
		b.WriteString(fmt.Sprintf("%s:\n%s\n\n", msg.Author, msg.Content))
	}
	return os.WriteFile(fp, []byte(b.String()), 0666)
}

// runRPBench alternates turns between the local model (playing the user)
// and the remote model (playing the character) for cfg.MaxTurns rounds.
func runRPBench(scenario *models.RPScenario) ([]models.RPMessage, error) {
	conversation := []models.RPMessage{}
	if scenario.FirstMsg != "" {
		conversation = append(conversation, models.RPMessage{Author: "Scenario", Content: scenario.FirstMsg})
	}
	for i := 0; i < cfg.MaxTurns; i++ {
		// User's turn (local model)
		userPrompt := buildRPPrompt(scenario.UserSysPrompt, conversation, scenario.UserName)
		userResp, err := callLLM(userPrompt, cfg.LocalAPI)
		if err != nil {
			return nil, fmt.Errorf("user turn failed: %w", err)
		}
		parser := chooseParser(cfg.LocalAPI)
		userText, err := parser.ParseBytes(userResp)
		if err != nil {
			return nil, fmt.Errorf("user turn failed parsing response: %w", err)
		}
		conversation = append(conversation, models.RPMessage{Author: scenario.UserName, Content: userText})

		// Character's turn (remote model)
		charPrompt := buildRPPrompt(scenario.CharSysPrompt, conversation, scenario.CharName)
		charResp, err := callLLM(charPrompt, cfg.RemoteAPI)
		if err != nil {
			return nil, fmt.Errorf("char turn failed: %w", err)
		}
		parser = chooseParser(cfg.RemoteAPI)
		charText, err := parser.ParseBytes(charResp)
		if err != nil {
			return nil, fmt.Errorf("char turn failed parsing response: %w", err)
		}
		conversation = append(conversation, models.RPMessage{Author: scenario.CharName, Content: charText})
	}
	return conversation, nil
}

func buildRPPrompt(sysPrompt string, history []models.RPMessage, characterName string) string {
	var b strings.Builder
	b.WriteString(sysPrompt)
	b.WriteString("\n\n")
	for _, msg := range history {
		b.WriteString(fmt.Sprintf("%s: %s\n", msg.Author, msg.Content))
	}
	b.WriteString(fmt.Sprintf("%s:", characterName))
	return b.String()
}

func processLLMResponse(respText string) (string, models.ToolCallInfo) {
	// Check if the response indicates tool usage
	var toolCall models.ToolCallInfo
	if strings.HasPrefix(respText, "[TOOL_CALL:") && strings.HasSuffix(respText, "]") {
		// Extract tool name from the marker
		toolName := strings.TrimPrefix(strings.TrimSuffix(respText, "]"), "[TOOL_CALL:")
		toolCall = models.ToolCallInfo{
			Name: toolName,
			// Arguments would need to be parsed from the actual response
		}
		// Remove the marker from the response text
		respText = fmt.Sprintf("Used tool: %s", toolName)
	}
	return respText, toolCall
}

func runBench(questions []models.Question) ([]models.Answer, error) {
	answers := []models.Answer{}
	for _, q := range questions {
		resp, err := callLLM(buildPrompt(q.Question), cfg.RemoteAPI) // Assuming remote for questions
		if err != nil {
			logger.Error("failed to call llm", "error", err)
			continue
		}
		parser := chooseParser(cfg.RemoteAPI)
		respText, err := parser.ParseBytes(resp)
		if err != nil {
			logger.Error("failed to parse llm response", "error", err)
			continue
		}
		// Process the response to detect tool usage
		respText, toolCall := processLLMResponse(respText)
		a := models.Answer{
			Q:        q,
			Answer:   respText,
			Model:    "default", // model name from resp?
			ToolCall: toolCall,
		}
		answers = append(answers, a)
	}
	return answers, nil
}

func buildPrompt(q string) string {
	// Use an interpreted string so the \n escapes become real newlines
	// (inside a raw backtick string they would be sent to the model literally).
	return fmt.Sprintf("Q:\n%s\nA:\nSure,", q)
}

// callLLM POSTs the prompt to apiURL and returns the raw response body.
// Transport errors and HTTP status codes >= 400 are retried with
// exponential backoff, up to maxRetries attempts.
func callLLM(prompt string, apiURL string) ([]byte, error) {
	method := "POST"
	parser := chooseParser(apiURL)
	payloadReader := parser.MakePayload(prompt)
	var payloadBytes []byte
	if payloadReader != nil {
		var err error
		payloadBytes, err = io.ReadAll(payloadReader)
		if err != nil {
			return nil, fmt.Errorf("failed to read payload: %w", err)
		}
	}

	maxRetries := 6
	baseDelay := 2 * time.Second
	for attempt := 0; attempt < maxRetries; attempt++ {
		req, err := http.NewRequest(method, apiURL, bytes.NewReader(payloadBytes))
		if err != nil {
			return nil, fmt.Errorf("failed to create request: %w", err)
		}
		req.Header.Add("Content-Type", "application/json")
		req.Header.Add("Accept", "application/json")
		if apiURL == cfg.RemoteAPI {
			req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
		}

		resp, err := httpClient.Do(req)
		if err != nil {
			if attempt == maxRetries-1 {
				return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
			}
			logger.Error("http request failed; will retry", "error", err, "url", apiURL, "attempt", attempt)
			time.Sleep(baseDelay * time.Duration(1<<attempt))
			continue
		}

		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("failed to read response body: %w", err)
		}

		if resp.StatusCode >= 400 {
			if attempt == maxRetries-1 {
				return nil, fmt.Errorf("LLM call failed after %d retries, got status %d, body: %s", maxRetries, resp.StatusCode, string(body))
			}
			logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt, "body", string(body))
			time.Sleep(baseDelay * time.Duration(1<<attempt))
			continue
		}

		return body, nil
	}
	return nil, errors.New("callLLM: retry loop exited without a response")
}
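
// NOTE: RespParser, NewOpenRouterParser and NewLCPRespParser are provided by
// another file in package main and are not shown here. From the way they are
// used above (MakePayload's result is drained with io.ReadAll before being
// POSTed, and ParseBytes turns a raw response body into the reply text), the
// contract is presumably close to the sketch below; the exact signatures are
// an assumption, not taken from this file.
//
//	type RespParser interface {
//		// MakePayload builds the JSON request body for a prompt.
//		MakePayload(prompt string) io.Reader
//		// ParseBytes extracts the model's reply text from a raw API response.
//		ParseBytes(raw []byte) (string, error)
//	}
//
//	func NewOpenRouterParser(logger *slog.Logger) RespParser
//	func NewLCPRespParser(logger *slog.Logger) RespParser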