Feat: llm call retries and model switch
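
Retry the LLM HTTP call up to six times with an increasing delay before giving
up, instead of failing on the first error, and rotate OpenRouter requests
through a list of free models instead of always using
google/gemini-2.0-flash-exp:free.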
@@ -14,6 +14,7 @@ import (
 	"os"
 	"strconv"
 	"strings"
+	"time"
 )
 
 var (
@@ -427,27 +428,64 @@ func (b *Bot) BuildPrompt(room *models.Room) string {
 
 func (b *Bot) CallLLM(prompt string) ([]byte, error) {
 	method := "POST"
-	payload := b.LLMParser.MakePayload(prompt)
+	// Generate the payload once as bytes
+	payloadReader := b.LLMParser.MakePayload(prompt)
 	client := &http.Client{}
-	req, err := http.NewRequest(method, b.cfg.LLMConfig.URL, payload)
-	if err != nil {
-		b.log.Error("failed to make new request", "error", err, "url", b.cfg.LLMConfig.URL)
-		return nil, err
-	}
-	req.Header.Add("Content-Type", "application/json")
-	req.Header.Add("Accept", "application/json")
-	req.Header.Add("Authorization", "Bearer "+b.cfg.LLMConfig.TOKEN)
-	res, err := client.Do(req)
-	if err != nil {
-		b.log.Error("failed to make request", "error", err, "url", b.cfg.LLMConfig.URL)
-		return nil, err
-	}
-	defer res.Body.Close()
-	body, err := io.ReadAll(res.Body)
-	if err != nil {
-		b.log.Error("failed to read resp body", "error", err, "url", b.cfg.LLMConfig.URL)
-		return nil, err
-	}
-	b.log.Debug("llm resp", "body", string(body), "url", b.cfg.LLMConfig.URL)
-	return body, nil
+	maxRetries := 6
+	baseDelay := 2 // seconds
+	for attempt := 0; attempt < maxRetries; attempt++ {
+		// Create a new request for the attempt
+		req, err := http.NewRequest(method, b.cfg.LLMConfig.URL, payloadReader)
+		if err != nil {
+			if attempt == maxRetries-1 {
+				return nil, fmt.Errorf("failed to create request: %w", err)
+			}
+			b.log.Error("failed to make new request; will retry", "error", err, "url", b.cfg.LLMConfig.URL, "attempt", attempt)
+			time.Sleep(time.Duration(baseDelay) * time.Second * time.Duration(attempt+1))
+			continue
+		}
+		req.Header.Add("Content-Type", "application/json")
+		req.Header.Add("Accept", "application/json")
+		req.Header.Add("Authorization", "Bearer "+b.cfg.LLMConfig.TOKEN)
+		resp, err := client.Do(req)
+		if err != nil {
+			if attempt == maxRetries-1 {
+				return nil, fmt.Errorf("http request failed: %w", err)
+			}
+			b.log.Error("http request failed; will retry", "error", err, "url", b.cfg.LLMConfig.URL, "attempt", attempt)
+			delay := time.Duration(baseDelay*(attempt+1)) * time.Second
+			time.Sleep(delay)
+			continue
+		}
+		body, err := io.ReadAll(resp.Body)
+		resp.Body.Close()
+		if err != nil {
+			if attempt == maxRetries-1 {
+				return nil, fmt.Errorf("failed to read response body: %w", err)
+			}
+			b.log.Error("failed to read response body; will retry", "error", err, "url", b.cfg.LLMConfig.URL, "attempt", attempt)
+			delay := time.Duration(baseDelay*(attempt+1)) * time.Second
+			time.Sleep(delay)
+			continue
+		}
+		// Check status code
+		if resp.StatusCode >= 400 && resp.StatusCode < 600 {
+			if attempt == maxRetries-1 {
+				return nil, fmt.Errorf("after %d retries, still got status %d", maxRetries, resp.StatusCode)
+			}
+			b.log.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt)
+			delay := time.Duration(baseDelay*(1<<attempt)) * time.Second
+			time.Sleep(delay)
+			continue
+		}
+		if resp.StatusCode != http.StatusOK {
+			// For non-retriable errors, return immediately
+			return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body))
+		}
+		// Success
+		b.log.Debug("llm resp", "body", string(body), "url", b.cfg.LLMConfig.URL, "attempt", attempt)
+		return body, nil
+	}
+	// This line should not be reached because each error path returns in the loop.
+	return nil, fmt.Errorf("unknown error in retry loop")
 }
@@ -86,12 +86,6 @@ func (p *lcpRespParser) ParseBytes(body []byte) (map[string]any, error) {
 		p.log.Error("failed to unmarshal", "error", err)
 		return nil, err
 	}
-	// if len(resp.Choices) == 0 {
-	// 	p.log.Error("empty choices", "resp", resp)
-	// 	err := errors.New("empty choices in resp")
-	// 	return nil, err
-	// }
-	// text := resp.Choices[0].Message.Content
 	text := resp.Content
 	li := strings.Index(text, "{")
 	ri := strings.LastIndex(text, "}")
@@ -123,11 +117,15 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader {
 }
 
 type openRouterParser struct {
-	log *slog.Logger
+	log        *slog.Logger
+	modelIndex uint32
 }
 
 func NewOpenRouterParser(log *slog.Logger) *openRouterParser {
-	return &openRouterParser{log: log}
+	return &openRouterParser{
+		log:        log,
+		modelIndex: 0,
+	}
 }
 
 func (p *openRouterParser) ParseBytes(body []byte) (map[string]any, error) {
@@ -160,17 +158,28 @@ func (p *openRouterParser) ParseBytes(body []byte) (map[string]any, error) {
 }
 
 func (p *openRouterParser) MakePayload(prompt string) io.Reader {
-	// "model": "deepseek/deepseek-chat-v3-0324:free",
+	// TODO: set list of models an option to pick on the frontend
+	// Models to rotate through
+	models := []string{
+		"google/gemini-2.0-flash-exp:free",
+		"deepseek/deepseek-chat-v3-0324:free",
+		"mistralai/mistral-small-3.2-24b-instruct:free",
+		"qwen/qwen3-14b:free",
+		"deepseek/deepseek-r1:free",
+		"google/gemma-3-27b-it:free",
+		"meta-llama/llama-3.3-70b-instruct:free",
+	}
+	// Get next model index using atomic addition for thread safety
+	p.modelIndex++
+	model := models[int(p.modelIndex)%len(models)]
 	strPayload := fmt.Sprintf(`{
-		"model": "google/gemini-2.0-flash-exp:free",
+		"model": "%s",
 		"messages": [
 			{
 				"role": "user",
 				"content": "%s"
 			}
 		]
-	}`, prompt)
-	p.log.Debug("made openrouter payload", "payload", strPayload)
+	}`, model, prompt)
+	p.log.Debug("made openrouter payload", "model", model, "payload", strPayload)
 	return strings.NewReader(strPayload)
 }