init
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
data
|
||||||
|
grailbench
|
||||||
|
*.json
|
||||||
|
config.toml
|
29
config/config.go
Normal file
29
config/config.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/BurntSushi/toml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
CurrentAPI string `toml:"CurrentAPI"`
|
||||||
|
APIToken string `toml:"APIToken"`
|
||||||
|
QuestionsPath string `toml:"QuestionsPath"`
|
||||||
|
OutPath string `toml:"OutPath"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfigOrDefault(fn string) *Config {
|
||||||
|
if fn == "" {
|
||||||
|
fn = "config.toml"
|
||||||
|
}
|
||||||
|
config := &Config{}
|
||||||
|
_, err := toml.DecodeFile(fn, &config)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("failed to read config from file, loading default", "error", err)
|
||||||
|
config.CurrentAPI = "http://localhost:8080/completion"
|
||||||
|
config.QuestionsPath = "data/questions.json"
|
||||||
|
config.OutPath = "data/out.json"
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
}
|
5
go.mod
Normal file
5
go.mod
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
module grailbench
|
||||||
|
|
||||||
|
go 1.24.5
|
||||||
|
|
||||||
|
require github.com/BurntSushi/toml v1.5.0
|
2
go.sum
Normal file
2
go.sum
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||||
|
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
168
main.go
Normal file
168
main.go
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"grailbench/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
logger *slog.Logger
|
||||||
|
cfg *config.Config
|
||||||
|
httpClient = &http.Client{}
|
||||||
|
currentModel = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
type Question struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
Question string `json:"question"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Answer struct {
|
||||||
|
Q Question
|
||||||
|
Answer string `json:"answer"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
// resp time?
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadQuestions(fp string) ([]Question, error) {
|
||||||
|
data, err := os.ReadFile(fp)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to read file", "error", err, "fp", fp)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp := []Question{}
|
||||||
|
if err := json.Unmarshal(data, &resp); err != nil {
|
||||||
|
logger.Error("failed to unmarshal file", "error", err, "fp", fp)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func chooseParser() RespParser {
|
||||||
|
return NewLCPRespParser(logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
|
||||||
|
Level: slog.LevelDebug,
|
||||||
|
AddSource: true,
|
||||||
|
}))
|
||||||
|
cfg = config.LoadConfigOrDefault("config.toml")
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
questions, err := loadQuestions(cfg.QuestionsPath)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
answers, err := runBench(questions)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
data, err := json.MarshalIndent(answers, " ", " ")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(cfg.OutPath, data, 0666); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runBench(questions []Question) ([]Answer, error) {
|
||||||
|
answers := []Answer{}
|
||||||
|
for _, q := range questions {
|
||||||
|
resp, err := callLLM(buildPrompt(q.Question))
|
||||||
|
if err != nil {
|
||||||
|
// log
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
parser := chooseParser()
|
||||||
|
respText, err := parser.ParseBytes(resp)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
a := Answer{Q: q, Answer: respText, Model: currentModel}
|
||||||
|
answers = append(answers, a)
|
||||||
|
}
|
||||||
|
return answers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// openai vs completion
|
||||||
|
func buildPrompt(q string) string {
|
||||||
|
// sure injection?
|
||||||
|
// completion
|
||||||
|
return fmt.Sprintf(`Q:\n%s\nA:\nSure,`, q)
|
||||||
|
}
|
||||||
|
|
||||||
|
func callLLM(prompt string) ([]byte, error) {
|
||||||
|
method := "POST"
|
||||||
|
// Generate the payload once as bytes
|
||||||
|
parser := chooseParser()
|
||||||
|
payloadReader := parser.MakePayload(prompt)
|
||||||
|
client := &http.Client{}
|
||||||
|
maxRetries := 6
|
||||||
|
baseDelay := 2 // seconds
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
// Create a new request for the attempt
|
||||||
|
req, err := http.NewRequest(method, cfg.CurrentAPI, payloadReader)
|
||||||
|
if err != nil {
|
||||||
|
if attempt == maxRetries-1 {
|
||||||
|
return nil, fmt.Errorf("LLM call failed after %d retries on request creation: %w", maxRetries, err)
|
||||||
|
}
|
||||||
|
logger.Error("failed to make new request; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
||||||
|
time.Sleep(time.Duration(baseDelay) * time.Second * time.Duration(attempt+1))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
req.Header.Add("Accept", "application/json")
|
||||||
|
req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if attempt == maxRetries-1 {
|
||||||
|
return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
|
||||||
|
}
|
||||||
|
logger.Error("http request failed; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
||||||
|
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
|
||||||
|
time.Sleep(delay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
if attempt == maxRetries-1 {
|
||||||
|
return nil, fmt.Errorf("LLM call failed after %d retries on reading body: %w", maxRetries, err)
|
||||||
|
}
|
||||||
|
logger.Error("failed to read response body; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
||||||
|
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
|
||||||
|
time.Sleep(delay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Check status code
|
||||||
|
if resp.StatusCode >= 400 && resp.StatusCode < 600 {
|
||||||
|
if attempt == maxRetries-1 {
|
||||||
|
return nil, fmt.Errorf("LLM call failed after %d retries, got status %d", maxRetries, resp.StatusCode)
|
||||||
|
}
|
||||||
|
logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt)
|
||||||
|
delay := time.Duration((baseDelay * (1 << attempt))) * time.Second
|
||||||
|
time.Sleep(delay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
// For non-retriable errors, return immediately
|
||||||
|
return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
// Success
|
||||||
|
logger.Debug("llm resp", "body", string(body), "url", cfg.CurrentAPI, "attempt", attempt)
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("unknown error in retry loop")
|
||||||
|
}
|
56
models.go
Normal file
56
models.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
type OpenRouterResp struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
Choices []struct {
|
||||||
|
Logprobs any `json:"logprobs"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
NativeFinishReason string `json:"native_finish_reason"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Refusal any `json:"refusal"`
|
||||||
|
Reasoning any `json:"reasoning"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DSResp struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Choices []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
} `json:"choices"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LLMResp struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Tokens []any `json:"tokens"`
|
||||||
|
IDSlot int `json:"id_slot"`
|
||||||
|
Stop bool `json:"stop"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
TokensPredicted int `json:"tokens_predicted"`
|
||||||
|
TokensEvaluated int `json:"tokens_evaluated"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
HasNewLine bool `json:"has_new_line"`
|
||||||
|
Truncated bool `json:"truncated"`
|
||||||
|
StopType string `json:"stop_type"`
|
||||||
|
StoppingWord string `json:"stopping_word"`
|
||||||
|
TokensCached int `json:"tokens_cached"`
|
||||||
|
}
|
146
parser.go
Normal file
146
parser.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RespParser interface {
|
||||||
|
ParseBytes(body []byte) (string, error)
|
||||||
|
MakePayload(prompt string) io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeepSeekParser: deepseek implementation of RespParser
|
||||||
|
type deepSeekParser struct {
|
||||||
|
log *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDeepSeekParser(log *slog.Logger) *deepSeekParser {
|
||||||
|
return &deepSeekParser{log: log}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
||||||
|
// parsing logic here
|
||||||
|
dsResp := DSResp{}
|
||||||
|
if err := json.Unmarshal(body, &dsResp); err != nil {
|
||||||
|
p.log.Error("failed to unmarshall", "error", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(dsResp.Choices) == 0 {
|
||||||
|
p.log.Error("empty choices", "dsResp", dsResp)
|
||||||
|
err := errors.New("empty choices in dsResp")
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
text := dsResp.Choices[0].Text
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
||||||
|
return strings.NewReader(fmt.Sprintf(`{
|
||||||
|
"model": "deepseek-chat",
|
||||||
|
"prompt": "%s",
|
||||||
|
"echo": false,
|
||||||
|
"frequency_penalty": 0,
|
||||||
|
"logprobs": 0,
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"presence_penalty": 0,
|
||||||
|
"stop": null,
|
||||||
|
"stream": false,
|
||||||
|
"stream_options": null,
|
||||||
|
"suffix": null,
|
||||||
|
"temperature": 1,
|
||||||
|
"top_p": 1
|
||||||
|
}`, prompt))
|
||||||
|
}
|
||||||
|
|
||||||
|
// llama.cpp implementation of RespParser
|
||||||
|
type lcpRespParser struct {
|
||||||
|
log *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLCPRespParser(log *slog.Logger) *lcpRespParser {
|
||||||
|
return &lcpRespParser{log: log}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
|
||||||
|
// parsing logic here
|
||||||
|
resp := LLMResp{}
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
p.log.Error("failed to unmarshal", "error", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return resp.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
|
||||||
|
return strings.NewReader(fmt.Sprintf(`{
|
||||||
|
"model": "local-model",
|
||||||
|
"prompt": "%s",
|
||||||
|
"frequency_penalty": 0,
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"stop": ["Q:\n", "A:\n"],
|
||||||
|
"stream": false,
|
||||||
|
"temperature": 0.4,
|
||||||
|
"top_p": 1
|
||||||
|
}`, prompt))
|
||||||
|
}
|
||||||
|
|
||||||
|
type openRouterParser struct {
|
||||||
|
log *slog.Logger
|
||||||
|
modelIndex uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
|
||||||
|
return &openRouterParser{
|
||||||
|
log: log,
|
||||||
|
modelIndex: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
||||||
|
// parsing logic here
|
||||||
|
resp := OpenRouterResp{}
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
p.log.Error("failed to unmarshal", "error", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(resp.Choices) == 0 {
|
||||||
|
p.log.Error("empty choices", "resp", resp)
|
||||||
|
err := errors.New("empty choices in resp")
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
text := resp.Choices[0].Message.Content
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
||||||
|
// Models to rotate through
|
||||||
|
models := []string{
|
||||||
|
"google/gemini-2.0-flash-exp:free",
|
||||||
|
"deepseek/deepseek-chat-v3-0324:free",
|
||||||
|
"mistralai/mistral-small-3.2-24b-instruct:free",
|
||||||
|
"qwen/qwen3-14b:free",
|
||||||
|
"deepseek/deepseek-r1:free",
|
||||||
|
"google/gemma-3-27b-it:free",
|
||||||
|
"meta-llama/llama-3.3-70b-instruct:free",
|
||||||
|
}
|
||||||
|
// Get next model index using atomic addition for thread safety
|
||||||
|
p.modelIndex++
|
||||||
|
model := models[int(p.modelIndex)%len(models)]
|
||||||
|
strPayload := fmt.Sprintf(`{
|
||||||
|
"model": "%s",
|
||||||
|
"max_tokens": 300,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "%s"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`, model, prompt)
|
||||||
|
p.log.Debug("made openrouter payload", "model", model, "payload", strPayload)
|
||||||
|
return strings.NewReader(strPayload)
|
||||||
|
}
|
Reference in New Issue
Block a user