169 lines
4.5 KiB
Go
169 lines
4.5 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
"grailbench/config"
|
|
)
|
|
|
|
var (
|
|
logger *slog.Logger
|
|
cfg *config.Config
|
|
httpClient = &http.Client{}
|
|
currentModel = ""
|
|
)
|
|
|
|
// Question is a single benchmark item as decoded from the questions JSON file.
type Question struct {
	// ID identifies the question (JSON key "id").
	ID string `json:"id"`
	// Topic is the question's topic label (JSON key "topic").
	Topic string `json:"topic"`
	// Question is the text that gets sent to the model (JSON key "question").
	Question string `json:"question"`
}
|
|
|
|
type Answer struct {
|
|
Q Question
|
|
Answer string `json:"answer"`
|
|
Model string `json:"model"`
|
|
// resp time?
|
|
}
|
|
|
|
func loadQuestions(fp string) ([]Question, error) {
|
|
data, err := os.ReadFile(fp)
|
|
if err != nil {
|
|
logger.Error("failed to read file", "error", err, "fp", fp)
|
|
return nil, err
|
|
}
|
|
resp := []Question{}
|
|
if err := json.Unmarshal(data, &resp); err != nil {
|
|
logger.Error("failed to unmarshal file", "error", err, "fp", fp)
|
|
return nil, err
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func chooseParser() RespParser {
|
|
return NewLCPRespParser(logger)
|
|
}
|
|
|
|
func init() {
|
|
logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
|
|
Level: slog.LevelDebug,
|
|
AddSource: true,
|
|
}))
|
|
cfg = config.LoadConfigOrDefault("config.toml")
|
|
}
|
|
|
|
func main() {
|
|
questions, err := loadQuestions(cfg.QuestionsPath)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
answers, err := runBench(questions)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
data, err := json.MarshalIndent(answers, " ", " ")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if err := os.WriteFile(cfg.OutPath, data, 0666); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func runBench(questions []Question) ([]Answer, error) {
|
|
answers := []Answer{}
|
|
for _, q := range questions {
|
|
resp, err := callLLM(buildPrompt(q.Question))
|
|
if err != nil {
|
|
// log
|
|
panic(err)
|
|
}
|
|
parser := chooseParser()
|
|
respText, err := parser.ParseBytes(resp)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
a := Answer{Q: q, Answer: respText, Model: currentModel}
|
|
answers = append(answers, a)
|
|
}
|
|
return answers, nil
|
|
}
|
|
|
|
// buildPrompt wraps the question text q in the completion-style prompt
// template. Only the plain completion format is produced here; an
// OpenAI-style chat format is not implemented.
//
// The template ends with "Sure," so the model's output continues from that
// word — presumably to nudge it towards answering; confirm this is wanted.
func buildPrompt(q string) string {
	// NOTE(review): this is a raw string literal, so each \n is a literal
	// backslash followed by 'n', not a newline. Confirm that the payload
	// builder expects the escape in this form.
	const template = `Q:\n%s\nA:\nSure,`
	return fmt.Sprintf(template, q)
}
|
|
|
|
func callLLM(prompt string) ([]byte, error) {
|
|
method := "POST"
|
|
// Generate the payload once as bytes
|
|
parser := chooseParser()
|
|
payloadReader := parser.MakePayload(prompt)
|
|
client := &http.Client{}
|
|
maxRetries := 6
|
|
baseDelay := 2 // seconds
|
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
|
// Create a new request for the attempt
|
|
req, err := http.NewRequest(method, cfg.CurrentAPI, payloadReader)
|
|
if err != nil {
|
|
if attempt == maxRetries-1 {
|
|
return nil, fmt.Errorf("LLM call failed after %d retries on request creation: %w", maxRetries, err)
|
|
}
|
|
logger.Error("failed to make new request; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
|
time.Sleep(time.Duration(baseDelay) * time.Second * time.Duration(attempt+1))
|
|
continue
|
|
}
|
|
req.Header.Add("Content-Type", "application/json")
|
|
req.Header.Add("Accept", "application/json")
|
|
req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
if attempt == maxRetries-1 {
|
|
return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
|
|
}
|
|
logger.Error("http request failed; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
|
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
|
|
time.Sleep(delay)
|
|
continue
|
|
}
|
|
body, err := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
if err != nil {
|
|
if attempt == maxRetries-1 {
|
|
return nil, fmt.Errorf("LLM call failed after %d retries on reading body: %w", maxRetries, err)
|
|
}
|
|
logger.Error("failed to read response body; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
|
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
|
|
time.Sleep(delay)
|
|
continue
|
|
}
|
|
// Check status code
|
|
if resp.StatusCode >= 400 && resp.StatusCode < 600 {
|
|
if attempt == maxRetries-1 {
|
|
return nil, fmt.Errorf("LLM call failed after %d retries, got status %d", maxRetries, resp.StatusCode)
|
|
}
|
|
logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt)
|
|
delay := time.Duration((baseDelay * (1 << attempt))) * time.Second
|
|
time.Sleep(delay)
|
|
continue
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
// For non-retriable errors, return immediately
|
|
return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body))
|
|
}
|
|
// Success
|
|
logger.Debug("llm resp", "body", string(body), "url", cfg.CurrentAPI, "attempt", attempt)
|
|
return body, nil
|
|
}
|
|
return nil, errors.New("unknown error in retry loop")
|
|
}
|