commit 757e948659
Author: Grail Finder
Date: 2025-07-31 10:46:07 +03:00

7 changed files with 410 additions and 0 deletions

.gitignore vendored Normal file (+4)

@@ -0,0 +1,4 @@
data
grailbench
*.json
config.toml

config/config.go Normal file (+29)

@@ -0,0 +1,29 @@
package config

import (
	"fmt"

	"github.com/BurntSushi/toml"
)

type Config struct {
	CurrentAPI    string `toml:"CurrentAPI"`
	APIToken      string `toml:"APIToken"`
	QuestionsPath string `toml:"QuestionsPath"`
	OutPath       string `toml:"OutPath"`
}

func LoadConfigOrDefault(fn string) *Config {
	if fn == "" {
		fn = "config.toml"
	}
	config := &Config{}
	_, err := toml.DecodeFile(fn, config)
	if err != nil {
		fmt.Println("failed to read config from file, loading default:", err)
		config.CurrentAPI = "http://localhost:8080/completion"
		config.QuestionsPath = "data/questions.json"
		config.OutPath = "data/out.json"
	}
	return config
}
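
For reference, a minimal config.toml consistent with the struct tags above might look like this (the token is a placeholder; the other values mirror the in-code defaults):

# example config.toml -- all values illustrative
CurrentAPI = "http://localhost:8080/completion"
APIToken = "sk-placeholder"
QuestionsPath = "data/questions.json"
OutPath = "data/out.json"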

go.mod Normal file (+5)

@@ -0,0 +1,5 @@
module grailbench

go 1.24.5

require github.com/BurntSushi/toml v1.5.0

go.sum Normal file (+2)

@@ -0,0 +1,2 @@
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=

main.go Normal file (+168)

@@ -0,0 +1,168 @@
package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"time"

	"grailbench/config"
)

var (
	logger       *slog.Logger
	cfg          *config.Config
	httpClient   = &http.Client{}
	currentModel = ""
)

type Question struct {
	ID       string `json:"id"`
	Topic    string `json:"topic"`
	Question string `json:"question"`
}

type Answer struct {
	Q      Question
	Answer string `json:"answer"`
	Model  string `json:"model"`
	// resp time?
}

func loadQuestions(fp string) ([]Question, error) {
	data, err := os.ReadFile(fp)
	if err != nil {
		logger.Error("failed to read file", "error", err, "fp", fp)
		return nil, err
	}
	resp := []Question{}
	if err := json.Unmarshal(data, &resp); err != nil {
		logger.Error("failed to unmarshal file", "error", err, "fp", fp)
		return nil, err
	}
	return resp, nil
}

func chooseParser() RespParser {
	return NewLCPRespParser(logger)
}

func init() {
	logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
		Level:     slog.LevelDebug,
		AddSource: true,
	}))
	cfg = config.LoadConfigOrDefault("config.toml")
}

func main() {
	questions, err := loadQuestions(cfg.QuestionsPath)
	if err != nil {
		panic(err)
	}
	answers, err := runBench(questions)
	if err != nil {
		panic(err)
	}
	data, err := json.MarshalIndent(answers, "", "  ")
	if err != nil {
		panic(err)
	}
	if err := os.WriteFile(cfg.OutPath, data, 0666); err != nil {
		panic(err)
	}
}

func runBench(questions []Question) ([]Answer, error) {
	answers := []Answer{}
	for _, q := range questions {
		resp, err := callLLM(buildPrompt(q.Question))
		if err != nil {
			logger.Error("llm call failed", "error", err, "question_id", q.ID)
			return nil, err
		}
		parser := chooseParser()
		respText, err := parser.ParseBytes(resp)
		if err != nil {
			logger.Error("failed to parse llm response", "error", err, "question_id", q.ID)
			return nil, err
		}
		a := Answer{Q: q, Answer: respText, Model: currentModel}
		answers = append(answers, a)
	}
	return answers, nil
}

// openai vs completion
func buildPrompt(q string) string {
	// The trailing "Sure," primes the completion endpoint to start answering
	// instead of refusing. Note that inside backticks \n stays a literal
	// backslash-n, which is intentional: the prompt is spliced raw into a JSON
	// string in MakePayload, where the server's JSON decoder turns \n back
	// into real newlines.
	return fmt.Sprintf(`Q:\n%s\nA:\nSure,`, q)
}

func callLLM(prompt string) ([]byte, error) {
	method := "POST"
	// Read the payload once into bytes: an io.Reader is drained by the first
	// attempt, so reusing it across retries would send an empty body.
	parser := chooseParser()
	payload, err := io.ReadAll(parser.MakePayload(prompt))
	if err != nil {
		return nil, fmt.Errorf("failed to read payload: %w", err)
	}
	maxRetries := 6
	baseDelay := 2 // seconds
	for attempt := 0; attempt < maxRetries; attempt++ {
		// Create a new request with a fresh body reader for each attempt
		req, err := http.NewRequest(method, cfg.CurrentAPI, bytes.NewReader(payload))
		if err != nil {
			if attempt == maxRetries-1 {
				return nil, fmt.Errorf("LLM call failed after %d retries on request creation: %w", maxRetries, err)
			}
			logger.Error("failed to make new request; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
			time.Sleep(time.Duration(baseDelay*(attempt+1)) * time.Second)
			continue
		}
		req.Header.Add("Content-Type", "application/json")
		req.Header.Add("Accept", "application/json")
		req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
		resp, err := httpClient.Do(req)
		if err != nil {
			if attempt == maxRetries-1 {
				return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
			}
			logger.Error("http request failed; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
			time.Sleep(time.Duration(baseDelay*(attempt+1)) * time.Second)
			continue
		}
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			if attempt == maxRetries-1 {
				return nil, fmt.Errorf("LLM call failed after %d retries on reading body: %w", maxRetries, err)
			}
			logger.Error("failed to read response body; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
			time.Sleep(time.Duration(baseDelay*(attempt+1)) * time.Second)
			continue
		}
		// Retry on 4xx/5xx with exponential backoff
		if resp.StatusCode >= 400 && resp.StatusCode < 600 {
			if attempt == maxRetries-1 {
				return nil, fmt.Errorf("LLM call failed after %d retries, got status %d", maxRetries, resp.StatusCode)
			}
			logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt)
			time.Sleep(time.Duration(baseDelay*(1<<attempt)) * time.Second)
			continue
		}
		if resp.StatusCode != http.StatusOK {
			// For non-retriable statuses, return immediately
			return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body))
		}
		// Success
		logger.Debug("llm resp", "body", string(body), "url", cfg.CurrentAPI, "attempt", attempt)
		return body, nil
	}
	return nil, errors.New("unknown error in retry loop")
}
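
loadQuestions expects a JSON array of Question objects, so a minimal data/questions.json compatible with the struct tags would look roughly like this (entries are illustrative):

[
  {"id": "1", "topic": "geography", "question": "What is the capital of France?"},
  {"id": "2", "topic": "math", "question": "What is 7 * 8?"}
]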

models.go Normal file (+56)

@@ -0,0 +1,56 @@
package main

type OpenRouterResp struct {
	ID       string `json:"id"`
	Provider string `json:"provider"`
	Model    string `json:"model"`
	Object   string `json:"object"`
	Created  int    `json:"created"`
	Choices  []struct {
		Logprobs           any    `json:"logprobs"`
		FinishReason       string `json:"finish_reason"`
		NativeFinishReason string `json:"native_finish_reason"`
		Index              int    `json:"index"`
		Message            struct {
			Role      string `json:"role"`
			Content   string `json:"content"`
			Refusal   any    `json:"refusal"`
			Reasoning any    `json:"reasoning"`
		} `json:"message"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}

type DSResp struct {
	ID      string `json:"id"`
	Choices []struct {
		Text         string `json:"text"`
		Index        int    `json:"index"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Created           int    `json:"created"`
	Model             string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
	Object            string `json:"object"`
}

type LLMResp struct {
	Index           int    `json:"index"`
	Content         string `json:"content"`
	Tokens          []any  `json:"tokens"`
	IDSlot          int    `json:"id_slot"`
	Stop            bool   `json:"stop"`
	Model           string `json:"model"`
	TokensPredicted int    `json:"tokens_predicted"`
	TokensEvaluated int    `json:"tokens_evaluated"`
	Prompt          string `json:"prompt"`
	HasNewLine      bool   `json:"has_new_line"`
	Truncated       bool   `json:"truncated"`
	StopType        string `json:"stop_type"`
	StoppingWord    string `json:"stopping_word"`
	TokensCached    int    `json:"tokens_cached"`
}
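
For orientation, LLMResp mirrors a llama.cpp /completion response; a trimmed response consistent with these tags would look roughly like this (values are illustrative, not captured output):

{
  "index": 0,
  "content": " Paris is the capital of France.",
  "model": "local-model",
  "stop": true,
  "stop_type": "word",
  "stopping_word": "Q:\n",
  "tokens_predicted": 9,
  "tokens_evaluated": 21,
  "tokens_cached": 0,
  "truncated": false,
  "has_new_line": true
}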

parser.go Normal file (+146)

@@ -0,0 +1,146 @@
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"sync/atomic"
)

type RespParser interface {
	ParseBytes(body []byte) (string, error)
	MakePayload(prompt string) io.Reader
}

// DeepSeekParser: deepseek implementation of RespParser
type deepSeekParser struct {
	log *slog.Logger
}

func NewDeepSeekParser(log *slog.Logger) *deepSeekParser {
	return &deepSeekParser{log: log}
}

func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
	dsResp := DSResp{}
	if err := json.Unmarshal(body, &dsResp); err != nil {
		p.log.Error("failed to unmarshal", "error", err)
		return "", err
	}
	if len(dsResp.Choices) == 0 {
		p.log.Error("empty choices", "dsResp", dsResp)
		return "", errors.New("empty choices in dsResp")
	}
	return dsResp.Choices[0].Text, nil
}

func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
	// NB: prompt is spliced into the JSON verbatim, so it must not contain
	// unescaped quotes or raw control characters.
	return strings.NewReader(fmt.Sprintf(`{
		"model": "deepseek-chat",
		"prompt": "%s",
		"echo": false,
		"frequency_penalty": 0,
		"logprobs": 0,
		"max_tokens": 1024,
		"presence_penalty": 0,
		"stop": null,
		"stream": false,
		"stream_options": null,
		"suffix": null,
		"temperature": 1,
		"top_p": 1
	}`, prompt))
}

// llama.cpp implementation of RespParser
type lcpRespParser struct {
	log *slog.Logger
}

func NewLCPRespParser(log *slog.Logger) *lcpRespParser {
	return &lcpRespParser{log: log}
}

func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
	resp := LLMResp{}
	if err := json.Unmarshal(body, &resp); err != nil {
		p.log.Error("failed to unmarshal", "error", err)
		return "", err
	}
	return resp.Content, nil
}

func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
	return strings.NewReader(fmt.Sprintf(`{
		"model": "local-model",
		"prompt": "%s",
		"frequency_penalty": 0,
		"max_tokens": 1024,
		"stop": ["Q:\n", "A:\n"],
		"stream": false,
		"temperature": 0.4,
		"top_p": 1
	}`, prompt))
}

type openRouterParser struct {
	log        *slog.Logger
	modelIndex uint32
}

func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
	return &openRouterParser{
		log:        log,
		modelIndex: 0,
	}
}

func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
	resp := OpenRouterResp{}
	if err := json.Unmarshal(body, &resp); err != nil {
		p.log.Error("failed to unmarshal", "error", err)
		return "", err
	}
	if len(resp.Choices) == 0 {
		p.log.Error("empty choices", "resp", resp)
		return "", errors.New("empty choices in resp")
	}
	return resp.Choices[0].Message.Content, nil
}

func (p *openRouterParser) MakePayload(prompt string) io.Reader {
	// Models to rotate through
	models := []string{
		"google/gemini-2.0-flash-exp:free",
		"deepseek/deepseek-chat-v3-0324:free",
		"mistralai/mistral-small-3.2-24b-instruct:free",
		"qwen/qwen3-14b:free",
		"deepseek/deepseek-r1:free",
		"google/gemma-3-27b-it:free",
		"meta-llama/llama-3.3-70b-instruct:free",
	}
	// Advance the rotation with an atomic add so concurrent callers are safe
	// (a plain increment would race)
	idx := atomic.AddUint32(&p.modelIndex, 1)
	model := models[int(idx)%len(models)]
	strPayload := fmt.Sprintf(`{
		"model": "%s",
		"max_tokens": 300,
		"messages": [
			{
				"role": "user",
				"content": "%s"
			}
		]
	}`, model, prompt)
	p.log.Debug("made openrouter payload", "model", model, "payload", strPayload)
	return strings.NewReader(strPayload)
}
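
Note: chooseParser in main.go currently hardcodes the llama.cpp parser even though all three parsers above are defined. A minimal sketch of URL-based selection, assuming the endpoint host is enough to identify the backend (this helper is hypothetical, not part of the commit):

// hypothetical helper, not in this commit: pick a parser from the configured
// API URL; the substring checks are assumptions about the endpoint hosts
func chooseParserFor(apiURL string, log *slog.Logger) RespParser {
	switch {
	case strings.Contains(apiURL, "openrouter.ai"):
		return NewOpenRouterParser(log)
	case strings.Contains(apiURL, "deepseek.com"):
		return NewDeepSeekParser(log)
	default:
		return NewLCPRespParser(log)
	}
}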