This commit is contained in:
Grail Finder
2025-07-31 10:46:07 +03:00
commit 757e948659
7 changed files with 410 additions and 0 deletions

168
main.go Normal file
View File

@@ -0,0 +1,168 @@
package main
import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"time"
"grailbench/config"
)
var (
logger *slog.Logger
cfg *config.Config
httpClient = &http.Client{}
currentModel = ""
)
type Question struct {
ID string `json:"id"`
Topic string `json:"topic"`
Question string `json:"question"`
}
type Answer struct {
Q Question
Answer string `json:"answer"`
Model string `json:"model"`
// resp time?
}
func loadQuestions(fp string) ([]Question, error) {
data, err := os.ReadFile(fp)
if err != nil {
logger.Error("failed to read file", "error", err, "fp", fp)
return nil, err
}
resp := []Question{}
if err := json.Unmarshal(data, &resp); err != nil {
logger.Error("failed to unmarshal file", "error", err, "fp", fp)
return nil, err
}
return resp, nil
}
func chooseParser() RespParser {
return NewLCPRespParser(logger)
}
func init() {
logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelDebug,
AddSource: true,
}))
cfg = config.LoadConfigOrDefault("config.toml")
}
func main() {
questions, err := loadQuestions(cfg.QuestionsPath)
if err != nil {
panic(err)
}
answers, err := runBench(questions)
if err != nil {
panic(err)
}
data, err := json.MarshalIndent(answers, " ", " ")
if err != nil {
panic(err)
}
if err := os.WriteFile(cfg.OutPath, data, 0666); err != nil {
panic(err)
}
}
func runBench(questions []Question) ([]Answer, error) {
answers := []Answer{}
for _, q := range questions {
resp, err := callLLM(buildPrompt(q.Question))
if err != nil {
// log
panic(err)
}
parser := chooseParser()
respText, err := parser.ParseBytes(resp)
if err != nil {
panic(err)
}
a := Answer{Q: q, Answer: respText, Model: currentModel}
answers = append(answers, a)
}
return answers, nil
}
// openai vs completion
func buildPrompt(q string) string {
// sure injection?
// completion
return fmt.Sprintf(`Q:\n%s\nA:\nSure,`, q)
}
func callLLM(prompt string) ([]byte, error) {
method := "POST"
// Generate the payload once as bytes
parser := chooseParser()
payloadReader := parser.MakePayload(prompt)
client := &http.Client{}
maxRetries := 6
baseDelay := 2 // seconds
for attempt := 0; attempt < maxRetries; attempt++ {
// Create a new request for the attempt
req, err := http.NewRequest(method, cfg.CurrentAPI, payloadReader)
if err != nil {
if attempt == maxRetries-1 {
return nil, fmt.Errorf("LLM call failed after %d retries on request creation: %w", maxRetries, err)
}
logger.Error("failed to make new request; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
time.Sleep(time.Duration(baseDelay) * time.Second * time.Duration(attempt+1))
continue
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
resp, err := client.Do(req)
if err != nil {
if attempt == maxRetries-1 {
return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
}
logger.Error("http request failed; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
time.Sleep(delay)
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
if attempt == maxRetries-1 {
return nil, fmt.Errorf("LLM call failed after %d retries on reading body: %w", maxRetries, err)
}
logger.Error("failed to read response body; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
time.Sleep(delay)
continue
}
// Check status code
if resp.StatusCode >= 400 && resp.StatusCode < 600 {
if attempt == maxRetries-1 {
return nil, fmt.Errorf("LLM call failed after %d retries, got status %d", maxRetries, resp.StatusCode)
}
logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt)
delay := time.Duration((baseDelay * (1 << attempt))) * time.Second
time.Sleep(delay)
continue
}
if resp.StatusCode != http.StatusOK {
// For non-retriable errors, return immediately
return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body))
}
// Success
logger.Debug("llm resp", "body", string(body), "url", cfg.CurrentAPI, "attempt", attempt)
return body, nil
}
return nil, errors.New("unknown error in retry loop")
}