init
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
data
|
||||
grailbench
|
||||
*.json
|
||||
config.toml
|
29
config/config.go
Normal file
29
config/config.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
// Config holds the benchmark's runtime settings, decoded from a TOML file.
type Config struct {
	// CurrentAPI is the LLM endpoint URL requests are POSTed to.
	CurrentAPI string `toml:"CurrentAPI"`
	// APIToken is sent as the Bearer token in the Authorization header.
	APIToken string `toml:"APIToken"`
	// QuestionsPath is the JSON file the benchmark questions are read from.
	QuestionsPath string `toml:"QuestionsPath"`
	// OutPath is the JSON file collected answers are written to.
	OutPath string `toml:"OutPath"`
}
|
||||
|
||||
func LoadConfigOrDefault(fn string) *Config {
|
||||
if fn == "" {
|
||||
fn = "config.toml"
|
||||
}
|
||||
config := &Config{}
|
||||
_, err := toml.DecodeFile(fn, &config)
|
||||
if err != nil {
|
||||
fmt.Println("failed to read config from file, loading default", "error", err)
|
||||
config.CurrentAPI = "http://localhost:8080/completion"
|
||||
config.QuestionsPath = "data/questions.json"
|
||||
config.OutPath = "data/out.json"
|
||||
}
|
||||
return config
|
||||
}
|
5
go.mod
Normal file
5
go.mod
Normal file
@@ -0,0 +1,5 @@
|
||||
module grailbench
|
||||
|
||||
go 1.24.5
|
||||
|
||||
require github.com/BurntSushi/toml v1.5.0
|
2
go.sum
Normal file
2
go.sum
Normal file
@@ -0,0 +1,2 @@
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
168
main.go
Normal file
168
main.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"grailbench/config"
|
||||
)
|
||||
|
||||
// Package-level state shared across the benchmark run.
var (
	// logger is the structured logger; initialized in init.
	logger *slog.Logger
	// cfg is the runtime configuration; loaded in init.
	cfg *config.Config
	// httpClient is shared so HTTP connections can be reused across calls.
	httpClient = &http.Client{}
	// currentModel labels answers with the model that produced them.
	currentModel = ""
)
|
||||
|
||||
// Question is a single benchmark item loaded from the questions JSON file.
type Question struct {
	// ID uniquely identifies the question.
	ID string `json:"id"`
	// Topic is the subject area of the question.
	Topic string `json:"topic"`
	// Question is the text sent to the LLM.
	Question string `json:"question"`
}
|
||||
|
||||
// Answer pairs a benchmark question with the model's response.
type Answer struct {
	// Q is the question this answer responds to.
	Q Question
	// Answer is the raw text extracted from the model's response.
	Answer string `json:"answer"`
	// Model identifies which model produced the answer.
	Model string `json:"model"`
	// TODO: consider also recording response latency here.
}
|
||||
|
||||
func loadQuestions(fp string) ([]Question, error) {
|
||||
data, err := os.ReadFile(fp)
|
||||
if err != nil {
|
||||
logger.Error("failed to read file", "error", err, "fp", fp)
|
||||
return nil, err
|
||||
}
|
||||
resp := []Question{}
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
logger.Error("failed to unmarshal file", "error", err, "fp", fp)
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// chooseParser returns the RespParser implementation used for API calls.
// Currently hard-wired to the llama.cpp parser; swap the constructor here
// to target the deepseek or openrouter backends instead.
func chooseParser() RespParser {
	return NewLCPRespParser(logger)
}
|
||||
|
||||
// init wires up the package-level logger and loads configuration before
// main runs. NOTE(review): init with side effects (config file I/O) makes
// testing harder; consider moving this setup into main.
func init() {
	// JSON-structured logs on stderr at debug level, with source locations.
	logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
		Level: slog.LevelDebug,
		AddSource: true,
	}))
	cfg = config.LoadConfigOrDefault("config.toml")
}
|
||||
|
||||
func main() {
|
||||
questions, err := loadQuestions(cfg.QuestionsPath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
answers, err := runBench(questions)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
data, err := json.MarshalIndent(answers, " ", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := os.WriteFile(cfg.OutPath, data, 0666); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func runBench(questions []Question) ([]Answer, error) {
|
||||
answers := []Answer{}
|
||||
for _, q := range questions {
|
||||
resp, err := callLLM(buildPrompt(q.Question))
|
||||
if err != nil {
|
||||
// log
|
||||
panic(err)
|
||||
}
|
||||
parser := chooseParser()
|
||||
respText, err := parser.ParseBytes(resp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
a := Answer{Q: q, Answer: respText, Model: currentModel}
|
||||
answers = append(answers, a)
|
||||
}
|
||||
return answers, nil
|
||||
}
|
||||
|
||||
// openai vs completion
|
||||
// buildPrompt wraps a raw question in the completion-style template used
// by the benchmark, priming the model with "Sure," to encourage an answer.
//
// Bug fix: the original used a raw (backtick) string literal, so each
// "\n" was sent as a literal backslash-n byte pair rather than a newline.
// An interpreted string literal restores the intended layout, matching
// the "Q:\n"/"A:\n" stop strings in the llama.cpp payload.
func buildPrompt(q string) string {
	// NOTE(review): q is interpolated verbatim — prompt injection is
	// possible if questions come from an untrusted source.
	return fmt.Sprintf("Q:\n%s\nA:\nSure,", q)
}
|
||||
|
||||
func callLLM(prompt string) ([]byte, error) {
|
||||
method := "POST"
|
||||
// Generate the payload once as bytes
|
||||
parser := chooseParser()
|
||||
payloadReader := parser.MakePayload(prompt)
|
||||
client := &http.Client{}
|
||||
maxRetries := 6
|
||||
baseDelay := 2 // seconds
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
// Create a new request for the attempt
|
||||
req, err := http.NewRequest(method, cfg.CurrentAPI, payloadReader)
|
||||
if err != nil {
|
||||
if attempt == maxRetries-1 {
|
||||
return nil, fmt.Errorf("LLM call failed after %d retries on request creation: %w", maxRetries, err)
|
||||
}
|
||||
logger.Error("failed to make new request; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
||||
time.Sleep(time.Duration(baseDelay) * time.Second * time.Duration(attempt+1))
|
||||
continue
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
if attempt == maxRetries-1 {
|
||||
return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
|
||||
}
|
||||
logger.Error("http request failed; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
||||
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
|
||||
time.Sleep(delay)
|
||||
continue
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
if attempt == maxRetries-1 {
|
||||
return nil, fmt.Errorf("LLM call failed after %d retries on reading body: %w", maxRetries, err)
|
||||
}
|
||||
logger.Error("failed to read response body; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
|
||||
delay := time.Duration(baseDelay*(attempt+1)) * time.Second
|
||||
time.Sleep(delay)
|
||||
continue
|
||||
}
|
||||
// Check status code
|
||||
if resp.StatusCode >= 400 && resp.StatusCode < 600 {
|
||||
if attempt == maxRetries-1 {
|
||||
return nil, fmt.Errorf("LLM call failed after %d retries, got status %d", maxRetries, resp.StatusCode)
|
||||
}
|
||||
logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt)
|
||||
delay := time.Duration((baseDelay * (1 << attempt))) * time.Second
|
||||
time.Sleep(delay)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// For non-retriable errors, return immediately
|
||||
return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
// Success
|
||||
logger.Debug("llm resp", "body", string(body), "url", cfg.CurrentAPI, "attempt", attempt)
|
||||
return body, nil
|
||||
}
|
||||
return nil, errors.New("unknown error in retry loop")
|
||||
}
|
56
models.go
Normal file
56
models.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package main
|
||||
|
||||
// OpenRouterResp mirrors the JSON body of an openrouter chat-completions
// response; the answer text lives in Choices[i].Message.Content.
type OpenRouterResp struct {
	ID       string `json:"id"`
	Provider string `json:"provider"`
	Model    string `json:"model"`
	Object   string `json:"object"`
	Created  int    `json:"created"`
	// Choices carries the generated completions; this code reads index 0.
	Choices []struct {
		Logprobs           any    `json:"logprobs"`
		FinishReason       string `json:"finish_reason"`
		NativeFinishReason string `json:"native_finish_reason"`
		Index              int    `json:"index"`
		Message            struct {
			Role      string `json:"role"`
			Content   string `json:"content"`
			Refusal   any    `json:"refusal"`
			Reasoning any    `json:"reasoning"`
		} `json:"message"`
	} `json:"choices"`
	// Usage reports token accounting for the request.
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
|
||||
|
||||
// DSResp mirrors the JSON body of a deepseek text-completion response;
// the answer text lives in Choices[i].Text.
type DSResp struct {
	ID string `json:"id"`
	// Choices carries the generated completions; this code reads index 0.
	Choices []struct {
		Text         string `json:"text"`
		Index        int    `json:"index"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Created           int    `json:"created"`
	Model             string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
	Object            string `json:"object"`
}
|
||||
|
||||
// LLMResp mirrors the JSON body of a llama.cpp completion response; the
// answer text lives in Content.
type LLMResp struct {
	Index int `json:"index"`
	// Content is the generated completion text.
	Content string `json:"content"`
	Tokens  []any  `json:"tokens"`
	IDSlot  int    `json:"id_slot"`
	// Stop reports whether generation stopped (vs. was truncated).
	Stop            bool   `json:"stop"`
	Model           string `json:"model"`
	TokensPredicted int    `json:"tokens_predicted"`
	TokensEvaluated int    `json:"tokens_evaluated"`
	// Prompt echoes the prompt the server received.
	Prompt       string `json:"prompt"`
	HasNewLine   bool   `json:"has_new_line"`
	Truncated    bool   `json:"truncated"`
	StopType     string `json:"stop_type"`
	StoppingWord string `json:"stopping_word"`
	TokensCached int    `json:"tokens_cached"`
}
|
146
parser.go
Normal file
146
parser.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package main
|
||||
|
||||
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"sync/atomic"
)
|
||||
|
||||
// RespParser abstracts one LLM API flavor: it renders the request payload
// for a prompt and extracts the answer text from the raw response body.
type RespParser interface {
	// ParseBytes extracts the completion text from an API response body.
	ParseBytes(body []byte) (string, error)
	// MakePayload renders the JSON request body for the given prompt.
	MakePayload(prompt string) io.Reader
}
|
||||
|
||||
// DeepSeekParser: deepseek implementation of RespParser
|
||||
// deepSeekParser is the deepseek implementation of RespParser.
type deepSeekParser struct {
	// log receives parse/build diagnostics.
	log *slog.Logger
}
|
||||
|
||||
// NewDeepSeekParser returns a deepseek parser that logs through log.
func NewDeepSeekParser(log *slog.Logger) *deepSeekParser {
	return &deepSeekParser{log: log}
}
|
||||
|
||||
func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
||||
// parsing logic here
|
||||
dsResp := DSResp{}
|
||||
if err := json.Unmarshal(body, &dsResp); err != nil {
|
||||
p.log.Error("failed to unmarshall", "error", err)
|
||||
return "", err
|
||||
}
|
||||
if len(dsResp.Choices) == 0 {
|
||||
p.log.Error("empty choices", "dsResp", dsResp)
|
||||
err := errors.New("empty choices in dsResp")
|
||||
return "", err
|
||||
}
|
||||
text := dsResp.Choices[0].Text
|
||||
return text, nil
|
||||
}
|
||||
|
||||
func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
||||
return strings.NewReader(fmt.Sprintf(`{
|
||||
"model": "deepseek-chat",
|
||||
"prompt": "%s",
|
||||
"echo": false,
|
||||
"frequency_penalty": 0,
|
||||
"logprobs": 0,
|
||||
"max_tokens": 1024,
|
||||
"presence_penalty": 0,
|
||||
"stop": null,
|
||||
"stream": false,
|
||||
"stream_options": null,
|
||||
"suffix": null,
|
||||
"temperature": 1,
|
||||
"top_p": 1
|
||||
}`, prompt))
|
||||
}
|
||||
|
||||
// llama.cpp implementation of RespParser
|
||||
// lcpRespParser is the llama.cpp implementation of RespParser.
type lcpRespParser struct {
	// log receives parse/build diagnostics.
	log *slog.Logger
}
|
||||
|
||||
// NewLCPRespParser returns a llama.cpp response parser that logs through log.
func NewLCPRespParser(log *slog.Logger) *lcpRespParser {
	return &lcpRespParser{log: log}
}
|
||||
|
||||
func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
|
||||
// parsing logic here
|
||||
resp := LLMResp{}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
p.log.Error("failed to unmarshal", "error", err)
|
||||
return "", err
|
||||
}
|
||||
return resp.Content, nil
|
||||
}
|
||||
|
||||
func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
|
||||
return strings.NewReader(fmt.Sprintf(`{
|
||||
"model": "local-model",
|
||||
"prompt": "%s",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 1024,
|
||||
"stop": ["Q:\n", "A:\n"],
|
||||
"stream": false,
|
||||
"temperature": 0.4,
|
||||
"top_p": 1
|
||||
}`, prompt))
|
||||
}
|
||||
|
||||
// openRouterParser is the openrouter implementation of RespParser,
// rotating through a list of free-tier models across payloads.
type openRouterParser struct {
	// log receives parse/build diagnostics.
	log *slog.Logger
	// modelIndex counts payloads built and selects the next model in the
	// rotation. NOTE(review): if payloads are ever built concurrently,
	// access to this counter must be atomic.
	modelIndex uint32
}
|
||||
|
||||
func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
|
||||
return &openRouterParser{
|
||||
log: log,
|
||||
modelIndex: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
||||
// parsing logic here
|
||||
resp := OpenRouterResp{}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
p.log.Error("failed to unmarshal", "error", err)
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
p.log.Error("empty choices", "resp", resp)
|
||||
err := errors.New("empty choices in resp")
|
||||
return "", err
|
||||
}
|
||||
text := resp.Choices[0].Message.Content
|
||||
return text, nil
|
||||
}
|
||||
|
||||
func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
||||
// Models to rotate through
|
||||
models := []string{
|
||||
"google/gemini-2.0-flash-exp:free",
|
||||
"deepseek/deepseek-chat-v3-0324:free",
|
||||
"mistralai/mistral-small-3.2-24b-instruct:free",
|
||||
"qwen/qwen3-14b:free",
|
||||
"deepseek/deepseek-r1:free",
|
||||
"google/gemma-3-27b-it:free",
|
||||
"meta-llama/llama-3.3-70b-instruct:free",
|
||||
}
|
||||
// Get next model index using atomic addition for thread safety
|
||||
p.modelIndex++
|
||||
model := models[int(p.modelIndex)%len(models)]
|
||||
strPayload := fmt.Sprintf(`{
|
||||
"model": "%s",
|
||||
"max_tokens": 300,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "%s"
|
||||
}
|
||||
]
|
||||
}`, model, prompt)
|
||||
p.log.Debug("made openrouter payload", "model", model, "payload", strPayload)
|
||||
return strings.NewReader(strPayload)
|
||||
}
|
Reference in New Issue
Block a user