diff --git a/main.go b/main.go index 552988d..ce658e6 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "time" "grailbench/config" + "grailbench/models" ) var ( @@ -20,26 +21,13 @@ var ( currentModel = "" ) -type Question struct { - ID string `json:"id"` - Topic string `json:"topic"` - Question string `json:"question"` -} - -type Answer struct { - Q Question - Answer string `json:"answer"` - Model string `json:"model"` - // resp time? -} - -func loadQuestions(fp string) ([]Question, error) { +func loadQuestions(fp string) ([]models.Question, error) { data, err := os.ReadFile(fp) if err != nil { logger.Error("failed to read file", "error", err, "fp", fp) return nil, err } - resp := []Question{} + resp := []models.Question{} if err := json.Unmarshal(data, &resp); err != nil { logger.Error("failed to unmarshal file", "error", err, "fp", fp) return nil, err @@ -77,8 +65,8 @@ func main() { } } -func runBench(questions []Question) ([]Answer, error) { - answers := []Answer{} +func runBench(questions []models.Question) ([]models.Answer, error) { + answers := []models.Answer{} for _, q := range questions { resp, err := callLLM(buildPrompt(q.Question)) if err != nil { @@ -90,7 +78,7 @@ func runBench(questions []Question) ([]Answer, error) { if err != nil { panic(err) } - a := Answer{Q: q, Answer: respText, Model: currentModel} + a := models.Answer{Q: q, Answer: respText, Model: currentModel} answers = append(answers, a) } return answers, nil diff --git a/models.go b/models.go index a8ee45c..06ab7d0 100644 --- a/models.go +++ b/models.go @@ -1,56 +1 @@ package main - -type OpenRouterResp struct { - ID string `json:"id"` - Provider string `json:"provider"` - Model string `json:"model"` - Object string `json:"object"` - Created int `json:"created"` - Choices []struct { - Logprobs any `json:"logprobs"` - FinishReason string `json:"finish_reason"` - NativeFinishReason string `json:"native_finish_reason"` - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - Refusal any `json:"refusal"` - Reasoning any `json:"reasoning"` - } `json:"message"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} - -type DSResp struct { - ID string `json:"id"` - Choices []struct { - Text string `json:"text"` - Index int `json:"index"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Created int `json:"created"` - Model string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Object string `json:"object"` -} - -type LLMResp struct { - Index int `json:"index"` - Content string `json:"content"` - Tokens []any `json:"tokens"` - IDSlot int `json:"id_slot"` - Stop bool `json:"stop"` - Model string `json:"model"` - TokensPredicted int `json:"tokens_predicted"` - TokensEvaluated int `json:"tokens_evaluated"` - Prompt string `json:"prompt"` - HasNewLine bool `json:"has_new_line"` - Truncated bool `json:"truncated"` - StopType string `json:"stop_type"` - StoppingWord string `json:"stopping_word"` - TokensCached int `json:"tokens_cached"` -} diff --git a/models/models.go b/models/models.go new file mode 100644 index 0000000..842a6eb --- /dev/null +++ b/models/models.go @@ -0,0 +1,69 @@ +package models + +type Question struct { + ID string `json:"id"` + Topic string `json:"topic"` + Question string `json:"question"` +} + +type Answer struct { + Q Question + Answer string `json:"answer"` + Model string `json:"model"` + // resp time? +} + +type OpenRouterResp struct { + ID string `json:"id"` + Provider string `json:"provider"` + Model string `json:"model"` + Object string `json:"object"` + Created int `json:"created"` + Choices []struct { + Logprobs any `json:"logprobs"` + FinishReason string `json:"finish_reason"` + NativeFinishReason string `json:"native_finish_reason"` + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + Refusal any `json:"refusal"` + Reasoning any `json:"reasoning"` + } `json:"message"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +type DSResp struct { + ID string `json:"id"` + Choices []struct { + Text string `json:"text"` + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Object string `json:"object"` +} + +type LLMResp struct { + Index int `json:"index"` + Content string `json:"content"` + Tokens []any `json:"tokens"` + IDSlot int `json:"id_slot"` + Stop bool `json:"stop"` + Model string `json:"model"` + TokensPredicted int `json:"tokens_predicted"` + TokensEvaluated int `json:"tokens_evaluated"` + Prompt string `json:"prompt"` + HasNewLine bool `json:"has_new_line"` + Truncated bool `json:"truncated"` + StopType string `json:"stop_type"` + StoppingWord string `json:"stopping_word"` + TokensCached int `json:"tokens_cached"` +} diff --git a/parser.go b/parser.go index 1aef72e..127dab7 100644 --- a/parser.go +++ b/parser.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "grailbench/models" "io" "log/slog" "strings" @@ -25,7 +26,7 @@ func NewDeepSeekParser(log *slog.Logger) *deepSeekParser { func (p *deepSeekParser) ParseBytes(body []byte) (string, error) { // parsing logic here - dsResp := DSResp{} + dsResp := models.DSResp{} if err := json.Unmarshal(body, &dsResp); err != nil { p.log.Error("failed to unmarshall", "error", err) return "", err @@ -68,7 +69,7 @@ func NewLCPRespParser(log *slog.Logger) *lcpRespParser { func (p *lcpRespParser) ParseBytes(body []byte) (string, error) { // parsing logic here - resp := LLMResp{} + resp := models.LLMResp{} if err := json.Unmarshal(body, &resp); err != nil { p.log.Error("failed to unmarshal", "error", err) return "", err @@ -103,7 +104,7 @@ func NewOpenRouterParser(log *slog.Logger) *openRouterParser { func (p *openRouterParser) ParseBytes(body []byte) (string, error) { // parsing logic here - resp := OpenRouterResp{} + resp := models.OpenRouterResp{} if err := json.Unmarshal(body, &resp); err != nil { p.log.Error("failed to unmarshal", "error", err) return "", err