181 lines
4.7 KiB
Go
181 lines
4.7 KiB
Go
package main
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"errors"
	"io"
	"log/slog"
	"sync/atomic"

	"grailbench/models"
)
|
|
|
|
// RespParser abstracts a single LLM backend: it builds the JSON request
// body for a prompt and extracts the completion text from the raw
// response bytes.
type RespParser interface {
	// ParseBytes extracts the completion text from a raw response body.
	ParseBytes(body []byte) (string, error)
	// MakePayload builds the JSON request body for the given prompt.
	// NOTE(review): the implementations in this file return nil when
	// marshalling fails — callers should check for a nil reader.
	MakePayload(prompt string) io.Reader
}
|
|
|
|
// deepSeekParser is the DeepSeek implementation of RespParser.
type deepSeekParser struct {
	log *slog.Logger // structured logger for parse/marshal failures
}
|
|
|
|
func NewDeepSeekParser(log *slog.Logger) *deepSeekParser {
|
|
return &deepSeekParser{log: log}
|
|
}
|
|
|
|
func (p *deepSeekParser) ParseBytes(body []byte) (string, error) {
|
|
// parsing logic here
|
|
dsResp := models.DSResp{}
|
|
if err := json.Unmarshal(body, &dsResp); err != nil {
|
|
p.log.Error("failed to unmarshall", "error", err)
|
|
return "", err
|
|
}
|
|
if len(dsResp.Choices) == 0 {
|
|
p.log.Error("empty choices", "dsResp", dsResp)
|
|
err := errors.New("empty choices in dsResp")
|
|
return "", err
|
|
}
|
|
text := dsResp.Choices[0].Text
|
|
return text, nil
|
|
}
|
|
|
|
func (p *deepSeekParser) MakePayload(prompt string) io.Reader {
|
|
payload := struct {
|
|
Model string `json:"model"`
|
|
Prompt string `json:"prompt"`
|
|
Echo bool `json:"echo"`
|
|
FrequencyPenalty float64 `json:"frequency_penalty"`
|
|
Logprobs int `json:"logprobs"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
PresencePenalty float64 `json:"presence_penalty"`
|
|
Stop interface{} `json:"stop"`
|
|
Stream bool `json:"stream"`
|
|
StreamOptions interface{} `json:"stream_options"`
|
|
Suffix interface{} `json:"suffix"`
|
|
Temperature float64 `json:"temperature"`
|
|
NProbs int `json:"n_probs"`
|
|
TopP float64 `json:"top_p"`
|
|
}{
|
|
Model: "deepseek-chat",
|
|
Prompt: prompt,
|
|
Echo: false,
|
|
FrequencyPenalty: 0,
|
|
Logprobs: 0,
|
|
MaxTokens: 1024,
|
|
PresencePenalty: 0,
|
|
Stop: nil,
|
|
Stream: false,
|
|
StreamOptions: nil,
|
|
Suffix: nil,
|
|
Temperature: 1,
|
|
NProbs: 10,
|
|
TopP: 1,
|
|
}
|
|
b, err := json.Marshal(payload)
|
|
if err != nil {
|
|
p.log.Error("failed to marshal deepseek payload", "error", err)
|
|
return nil
|
|
}
|
|
return bytes.NewReader(b)
|
|
}
|
|
|
|
// lcpRespParser is the llama.cpp implementation of RespParser.
type lcpRespParser struct {
	log *slog.Logger // structured logger for parse/marshal failures
}
|
|
|
|
func NewLCPRespParser(log *slog.Logger) *lcpRespParser {
|
|
return &lcpRespParser{log: log}
|
|
}
|
|
|
|
func (p *lcpRespParser) ParseBytes(body []byte) (string, error) {
|
|
// parsing logic here
|
|
resp := models.LLMResp{}
|
|
if err := json.Unmarshal(body, &resp); err != nil {
|
|
p.log.Error("failed to unmarshal", "error", err)
|
|
return "", err
|
|
}
|
|
return resp.Content, nil
|
|
}
|
|
|
|
func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
|
|
payload := struct {
|
|
Model string `json:"model"`
|
|
Prompt string `json:"prompt"`
|
|
FrequencyPenalty float64 `json:"frequency_penalty"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
Stop []string `json:"stop"`
|
|
Stream bool `json:"stream"`
|
|
Temperature float64 `json:"temperature"`
|
|
TopP float64 `json:"top_p"`
|
|
}{
|
|
Model: "local-model",
|
|
Prompt: prompt,
|
|
FrequencyPenalty: 0,
|
|
MaxTokens: 1024,
|
|
Stop: []string{"Q:\n", "A:\n"},
|
|
Stream: false,
|
|
Temperature: 0.4,
|
|
TopP: 1,
|
|
}
|
|
|
|
b, err := json.Marshal(payload)
|
|
if err != nil {
|
|
// This should not happen for this struct, but good practice to handle.
|
|
p.log.Error("failed to marshal lcp payload", "error", err)
|
|
return nil
|
|
}
|
|
return bytes.NewReader(b)
|
|
}
|
|
|
|
// openRouterParser is the OpenRouter implementation of RespParser.
type openRouterParser struct {
	log *slog.Logger // structured logger for parse/marshal failures
	// modelIndex counts payloads built. Its uint32 type and the comment
	// in MakePayload suggest it was intended for atomic model rotation —
	// TODO confirm: MakePayload currently increments it with a plain ++
	// and always uses a single hard-coded model.
	modelIndex uint32
}
|
|
|
|
func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
|
|
return &openRouterParser{
|
|
log: log,
|
|
modelIndex: 0,
|
|
}
|
|
}
|
|
|
|
func (p *openRouterParser) ParseBytes(body []byte) (string, error) {
|
|
// parsing logic here
|
|
resp := models.DSResp{}
|
|
if err := json.Unmarshal(body, &resp); err != nil {
|
|
p.log.Error("failed to unmarshal", "error", err)
|
|
return "", err
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
p.log.Error("empty choices", "resp", resp)
|
|
err := errors.New("empty choices in resp")
|
|
return "", err
|
|
}
|
|
text := resp.Choices[0].Text
|
|
return text, nil
|
|
}
|
|
|
|
func (p *openRouterParser) MakePayload(prompt string) io.Reader {
|
|
// Models to rotate through
|
|
// TODO: to config
|
|
model := "deepseek/deepseek-r1:free"
|
|
// Get next model index using atomic addition for thread safety
|
|
p.modelIndex++
|
|
payload := struct {
|
|
Model string `json:"model"`
|
|
Prompt string `json:"prompt"`
|
|
}{
|
|
Model: model,
|
|
Prompt: prompt,
|
|
}
|
|
|
|
b, err := json.Marshal(payload)
|
|
if err != nil {
|
|
p.log.Error("failed to marshal openrouter payload", "error", err)
|
|
return nil
|
|
}
|
|
p.log.Debug("made openrouter payload", "model", model, "payload", string(b))
|
|
return bytes.NewReader(b)
|
|
}
|