From 757e948659c1344112edab315455be8185fdbf96 Mon Sep 17 00:00:00 2001
From: Grail Finder
Date: Thu, 31 Jul 2025 10:46:07 +0300
Subject: [PATCH] init

---
 .gitignore       |   4 +
 config/config.go |  29 ++++++
 go.mod           |   5 +
 go.sum           |   2 +
 main.go          | 168 ++++++++++++++++++++++++++++++++++++++++++++++
 models.go        |  56 ++++++++++++++
 parser.go        | 147 +++++++++++++++++++++++++++++++++++++
 7 files changed, 411 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 config/config.go
 create mode 100644 go.mod
 create mode 100644 go.sum
 create mode 100644 main.go
 create mode 100644 models.go
 create mode 100644 parser.go

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f5a1ca9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+data
+grailbench
+*.json
+config.toml
diff --git a/config/config.go b/config/config.go
new file mode 100644
index 0000000..2887834
--- /dev/null
+++ b/config/config.go
@@ -0,0 +1,29 @@
+package config
+
+import (
+    "fmt"
+
+    "github.com/BurntSushi/toml"
+)
+
+type Config struct {
+    CurrentAPI    string `toml:"CurrentAPI"`
+    APIToken      string `toml:"APIToken"`
+    QuestionsPath string `toml:"QuestionsPath"`
+    OutPath       string `toml:"OutPath"`
+}
+
+func LoadConfigOrDefault(fn string) *Config {
+    if fn == "" {
+        fn = "config.toml"
+    }
+    config := &Config{}
+    _, err := toml.DecodeFile(fn, config)
+    if err != nil {
+        fmt.Printf("failed to read config from %q (%v); falling back to defaults\n", fn, err)
+        config.CurrentAPI = "http://localhost:8080/completion"
+        config.QuestionsPath = "data/questions.json"
+        config.OutPath = "data/out.json"
+    }
+    return config
+}
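Note: a config.toml matching the Config struct above might look like the
following (the values mirror the defaults in LoadConfigOrDefault; the token
is a placeholder):

    CurrentAPI = "http://localhost:8080/completion"
    APIToken = "your-token-here"
    QuestionsPath = "data/questions.json"
    OutPath = "data/out.json"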
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..0433f74
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module grailbench
+
+go 1.24.5
+
+require github.com/BurntSushi/toml v1.5.0
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..ff7fd09
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
+github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..552988d
--- /dev/null
+++ b/main.go
@@ -0,0 +1,168 @@
+package main
+
+import (
+    "encoding/json"
+    "errors"
+    "fmt"
+    "io"
+    "log/slog"
+    "net/http"
+    "os"
+    "time"
+
+    "grailbench/config"
+)
+
+var (
+    logger       *slog.Logger
+    cfg          *config.Config
+    httpClient   = &http.Client{}
+    currentModel = ""
+)
+
+type Question struct {
+    ID       string `json:"id"`
+    Topic    string `json:"topic"`
+    Question string `json:"question"`
+}
+
+type Answer struct {
+    Q      Question
+    Answer string `json:"answer"`
+    Model  string `json:"model"`
+    // resp time?
+}
+
+func loadQuestions(fp string) ([]Question, error) {
+    data, err := os.ReadFile(fp)
+    if err != nil {
+        logger.Error("failed to read file", "error", err, "fp", fp)
+        return nil, err
+    }
+    questions := []Question{}
+    if err := json.Unmarshal(data, &questions); err != nil {
+        logger.Error("failed to unmarshal file", "error", err, "fp", fp)
+        return nil, err
+    }
+    return questions, nil
+}
+
+func chooseParser() RespParser {
+    return NewLCPRespParser(logger)
+}
+
+func init() {
+    logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
+        Level:     slog.LevelDebug,
+        AddSource: true,
+    }))
+    cfg = config.LoadConfigOrDefault("config.toml")
+}
+
+func main() {
+    questions, err := loadQuestions(cfg.QuestionsPath)
+    if err != nil {
+        panic(err)
+    }
+    answers, err := runBench(questions)
+    if err != nil {
+        panic(err)
+    }
+    data, err := json.MarshalIndent(answers, "", "  ")
+    if err != nil {
+        panic(err)
+    }
+    if err := os.WriteFile(cfg.OutPath, data, 0666); err != nil {
+        panic(err)
+    }
+}
+
+func runBench(questions []Question) ([]Answer, error) {
+    answers := make([]Answer, 0, len(questions))
+    parser := chooseParser()
+    for _, q := range questions {
+        resp, err := callLLM(buildPrompt(q.Question))
+        if err != nil {
+            logger.Error("llm call failed", "error", err, "id", q.ID)
+            return answers, err
+        }
+        respText, err := parser.ParseBytes(resp)
+        if err != nil {
+            logger.Error("failed to parse llm response", "error", err, "id", q.ID)
+            return answers, err
+        }
+        a := Answer{Q: q, Answer: respText, Model: currentModel}
+        answers = append(answers, a)
+    }
+    return answers, nil
+}
+
+// buildPrompt wraps the question in a completion-style template;
+// the trailing "Sure," primes the model to start answering.
+func buildPrompt(q string) string {
+    // interpreted string literal, so the \n are real newlines
+    return fmt.Sprintf("Q:\n%s\nA:\nSure,", q)
+}
+
+func callLLM(prompt string) ([]byte, error) {
+    method := "POST"
+    parser := chooseParser()
+    maxRetries := 6
+    baseDelay := 2 // seconds
+    for attempt := 0; attempt < maxRetries; attempt++ {
+        // the payload reader is drained by each request, so it must be
+        // rebuilt on every attempt; reusing it would send an empty body
+        payloadReader := parser.MakePayload(prompt)
+        req, err := http.NewRequest(method, cfg.CurrentAPI, payloadReader)
+        if err != nil {
+            if attempt == maxRetries-1 {
+                return nil, fmt.Errorf("LLM call failed after %d retries on request creation: %w", maxRetries, err)
+            }
+            logger.Error("failed to make new request; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
+            time.Sleep(time.Duration(baseDelay*(1<<attempt)) * time.Second)
+            continue
+        }
+        req.Header.Add("Content-Type", "application/json")
+        req.Header.Add("Accept", "application/json")
+        req.Header.Add("Authorization", "Bearer "+cfg.APIToken)
+        resp, err := httpClient.Do(req)
+        if err != nil {
+            if attempt == maxRetries-1 {
+                return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err)
+            }
+            logger.Error("http request failed; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
+            delay := time.Duration(baseDelay*(1<<attempt)) * time.Second
+            time.Sleep(delay)
+            continue
+        }
+        body, err := io.ReadAll(resp.Body)
+        resp.Body.Close()
+        if err != nil {
+            if attempt == maxRetries-1 {
+                return nil, fmt.Errorf("LLM call failed after %d retries on reading body: %w", maxRetries, err)
+            }
+            logger.Error("failed to read response body; will retry", "error", err, "url", cfg.CurrentAPI, "attempt", attempt)
+            delay := time.Duration(baseDelay*(1<<attempt)) * time.Second
+            time.Sleep(delay)
+            continue
+        }
+        // 429 and 5xx are worth retrying, with exponential backoff
+        if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500 {
+            if attempt == maxRetries-1 {
+                return nil, fmt.Errorf("LLM call failed after %d retries, got status %d", maxRetries, resp.StatusCode)
+            }
+            logger.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt)
+            delay := time.Duration(baseDelay*(1<<attempt)) * time.Second
+            time.Sleep(delay)
+            continue
+        }
+        if resp.StatusCode != http.StatusOK {
+            // other non-200 statuses are not retriable; return immediately
+            return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body))
+        }
+        // success
+        logger.Debug("llm resp", "body", string(body), "url", cfg.CurrentAPI, "attempt", attempt)
+        return body, nil
+    }
+    return nil, errors.New("unknown error in retry loop")
+}
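Note: loadQuestions expects cfg.QuestionsPath to point at a JSON array of
Question objects, for example (illustrative content, not part of the patch):

    [
      {"id": "1", "topic": "go", "question": "What does the init function do?"}
    ]

Since Answer.Q carries no json tag, each entry written to OutPath marshals as
{"Q": {...}, "answer": "...", "model": "..."}.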
diff --git a/models.go b/models.go
new file mode 100644
index 0000000..a8ee45c
--- /dev/null
+++ b/models.go
@@ -0,0 +1,56 @@
+package main
+
+type OpenRouterResp struct {
+    ID       string `json:"id"`
+    Provider string `json:"provider"`
+    Model    string `json:"model"`
+    Object   string `json:"object"`
+    Created  int    `json:"created"`
+    Choices  []struct {
+        Logprobs           any    `json:"logprobs"`
+        FinishReason       string `json:"finish_reason"`
+        NativeFinishReason string `json:"native_finish_reason"`
+        Index              int    `json:"index"`
+        Message            struct {
+            Role      string `json:"role"`
+            Content   string `json:"content"`
+            Refusal   any    `json:"refusal"`
+            Reasoning any    `json:"reasoning"`
+        } `json:"message"`
+    } `json:"choices"`
+    Usage struct {
+        PromptTokens     int `json:"prompt_tokens"`
+        CompletionTokens int `json:"completion_tokens"`
+        TotalTokens      int `json:"total_tokens"`
+    } `json:"usage"`
+}
+
+type DSResp struct {
+    ID      string `json:"id"`
+    Choices []struct {
+        Text         string `json:"text"`
+        Index        int    `json:"index"`
+        FinishReason string `json:"finish_reason"`
+    } `json:"choices"`
+    Created           int    `json:"created"`
+    Model             string `json:"model"`
+    SystemFingerprint string `json:"system_fingerprint"`
+    Object            string `json:"object"`
+}
+
+type LLMResp struct {
+    Index           int    `json:"index"`
+    Content         string `json:"content"`
+    Tokens          []any  `json:"tokens"`
+    IDSlot          int    `json:"id_slot"`
+    Stop            bool   `json:"stop"`
+    Model           string `json:"model"`
+    TokensPredicted int    `json:"tokens_predicted"`
+    TokensEvaluated int    `json:"tokens_evaluated"`
+    Prompt          string `json:"prompt"`
+    HasNewLine      bool   `json:"has_new_line"`
+    Truncated       bool   `json:"truncated"`
+    StopType        string `json:"stop_type"`
+    StoppingWord    string `json:"stopping_word"`
+    TokensCached    int    `json:"tokens_cached"`
+}
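Note: the MakePayload implementations below splice the prompt into a JSON
template via fmt.Sprintf. Interpolating the raw string would produce invalid
JSON as soon as the prompt contains a quote or a newline, so the prompt is
passed through json.Marshal first, which yields a correctly escaped JSON
string literal. A minimal sketch of the difference:

    promptJSON, _ := json.Marshal("line1\nsaid \"hi\"")
    fmt.Println(string(promptJSON)) // prints: "line1\nsaid \"hi\""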
"presence_penalty": 0, + "stop": null, + "stream": false, + "stream_options": null, + "suffix": null, + "temperature": 1, + "top_p": 1 + }`, prompt)) +} + +// llama.cpp implementation of RespParser +type lcpRespParser struct { + log *slog.Logger +} + +func NewLCPRespParser(log *slog.Logger) *lcpRespParser { + return &lcpRespParser{log: log} +} + +func (p *lcpRespParser) ParseBytes(body []byte) (string, error) { + // parsing logic here + resp := LLMResp{} + if err := json.Unmarshal(body, &resp); err != nil { + p.log.Error("failed to unmarshal", "error", err) + return "", err + } + return resp.Content, nil +} + +func (p *lcpRespParser) MakePayload(prompt string) io.Reader { + return strings.NewReader(fmt.Sprintf(`{ + "model": "local-model", + "prompt": "%s", + "frequency_penalty": 0, + "max_tokens": 1024, + "stop": ["Q:\n", "A:\n"], + "stream": false, + "temperature": 0.4, + "top_p": 1 + }`, prompt)) +} + +type openRouterParser struct { + log *slog.Logger + modelIndex uint32 +} + +func NewOpenRouterParser(log *slog.Logger) *openRouterParser { + return &openRouterParser{ + log: log, + modelIndex: 0, + } +} + +func (p *openRouterParser) ParseBytes(body []byte) (string, error) { + // parsing logic here + resp := OpenRouterResp{} + if err := json.Unmarshal(body, &resp); err != nil { + p.log.Error("failed to unmarshal", "error", err) + return "", err + } + if len(resp.Choices) == 0 { + p.log.Error("empty choices", "resp", resp) + err := errors.New("empty choices in resp") + return "", err + } + text := resp.Choices[0].Message.Content + return text, nil +} + +func (p *openRouterParser) MakePayload(prompt string) io.Reader { + // Models to rotate through + models := []string{ + "google/gemini-2.0-flash-exp:free", + "deepseek/deepseek-chat-v3-0324:free", + "mistralai/mistral-small-3.2-24b-instruct:free", + "qwen/qwen3-14b:free", + "deepseek/deepseek-r1:free", + "google/gemma-3-27b-it:free", + "meta-llama/llama-3.3-70b-instruct:free", + } + // Get next model index using atomic addition for thread safety + p.modelIndex++ + model := models[int(p.modelIndex)%len(models)] + strPayload := fmt.Sprintf(`{ + "model": "%s", + "max_tokens": 300, + "messages": [ + { + "role": "user", + "content": "%s" + } + ] + }`, model, prompt) + p.log.Debug("made openrouter payload", "model", model, "payload", strPayload) + return strings.NewReader(strPayload) +}