Feat: llm call retries and model switch
This commit is contained in:
		| @@ -14,6 +14,7 @@ import ( | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| @@ -427,27 +428,64 @@ func (b *Bot) BuildPrompt(room *models.Room) string { | ||||
|  | ||||
// CallLLM POSTs the payload built from prompt to b.cfg.LLMConfig.URL and returns the raw response body.
// NOTE(review): this is a rendered diff without +/- markers, so the pre-change body (single request,
// no retries) and the post-change body (retry loop) appear interleaved below; read accordingly.
// In the retry-loop version up to maxRetries attempts are made, sleeping between attempts, and the
// http.Client is constructed without a Timeout.
| func (b *Bot) CallLLM(prompt string) ([]byte, error) { | ||||
| 	method := "POST" | ||||
| 	payload := b.LLMParser.MakePayload(prompt) | ||||
| 	// NOTE(review): built once, outside the retry loop, and passed as a reader (not bytes) to every attempt's NewRequest below; after the first send the reader is presumably drained, so retries would carry an empty body and reuse the same model - confirm, then rebuild the payload per attempt. | ||||
| 	payloadReader := b.LLMParser.MakePayload(prompt) | ||||
| 	client := &http.Client{} | ||||
| 	req, err := http.NewRequest(method, b.cfg.LLMConfig.URL, payload) | ||||
| 	if err != nil { | ||||
| 		b.log.Error("failed to make new request", "error", err, "url", b.cfg.LLMConfig.URL) | ||||
| 		return nil, err | ||||
| 	maxRetries := 6 | ||||
| 	baseDelay := 2 // seconds | ||||
| 	for attempt := 0; attempt < maxRetries; attempt++ { | ||||
| 		// Create a new request for the attempt | ||||
| 		req, err := http.NewRequest(method, b.cfg.LLMConfig.URL, payloadReader) | ||||
| 		if err != nil { | ||||
| 			if attempt == maxRetries-1 { | ||||
| 				return nil, fmt.Errorf("failed to create request: %w", err) | ||||
| 			} | ||||
| 			b.log.Error("failed to make new request; will retry", "error", err, "url", b.cfg.LLMConfig.URL, "attempt", attempt) | ||||
| 			time.Sleep(time.Duration(baseDelay) * time.Second * time.Duration(attempt+1)) | ||||
| 			continue | ||||
| 		} | ||||
| 		req.Header.Add("Content-Type", "application/json") | ||||
| 		req.Header.Add("Accept", "application/json") | ||||
| 		req.Header.Add("Authorization", "Bearer "+b.cfg.LLMConfig.TOKEN) | ||||
| 		resp, err := client.Do(req) | ||||
| 		if err != nil { | ||||
| 			if attempt == maxRetries-1 { | ||||
| 				return nil, fmt.Errorf("http request failed: %w", err) | ||||
| 			} | ||||
| 			b.log.Error("http request failed; will retry", "error", err, "url", b.cfg.LLMConfig.URL, "attempt", attempt) | ||||
| 			delay := time.Duration(baseDelay*(attempt+1)) * time.Second | ||||
| 			time.Sleep(delay) | ||||
| 			continue | ||||
| 		} | ||||
| 		body, err := io.ReadAll(resp.Body) | ||||
| 		resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			if attempt == maxRetries-1 { | ||||
| 				return nil, fmt.Errorf("failed to read response body: %w", err) | ||||
| 			} | ||||
| 			b.log.Error("failed to read response body; will retry", "error", err, "url", b.cfg.LLMConfig.URL, "attempt", attempt) | ||||
| 			delay := time.Duration(baseDelay*(attempt+1)) * time.Second | ||||
| 			time.Sleep(delay) | ||||
| 			continue | ||||
| 		} | ||||
| 		// Check status code. Every 4xx and 5xx is retried here, with exponential backoff (baseDelay * 2^attempt seconds), unlike the linear backoff of the error paths above. | ||||
| 		if resp.StatusCode >= 400 && resp.StatusCode < 600 { | ||||
| 			if attempt == maxRetries-1 { | ||||
| 				return nil, fmt.Errorf("after %d retries, still got status %d", maxRetries, resp.StatusCode) | ||||
| 			} | ||||
| 			b.log.Warn("retriable status code; will retry", "code", resp.StatusCode, "attempt", attempt) | ||||
| 			delay := time.Duration((baseDelay * (1 << attempt))) * time.Second | ||||
| 			time.Sleep(delay) | ||||
| 			continue | ||||
| 		} | ||||
| 		if resp.StatusCode != http.StatusOK { | ||||
| 			// Only reached for non-200 codes outside 400-599, since every 4xx/5xx is handled above; returned immediately without retrying. | ||||
| 			return nil, fmt.Errorf("non-retriable status %d, body: %s", resp.StatusCode, string(body)) | ||||
| 		} | ||||
| 		// Success | ||||
| 		b.log.Debug("llm resp", "body", string(body), "url", b.cfg.LLMConfig.URL, "attempt", attempt) | ||||
| 		return body, nil | ||||
| 	} | ||||
| 	req.Header.Add("Content-Type", "application/json") | ||||
| 	req.Header.Add("Accept", "application/json") | ||||
| 	req.Header.Add("Authorization", "Bearer "+b.cfg.LLMConfig.TOKEN) | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		b.log.Error("failed to make request", "error", err, "url", b.cfg.LLMConfig.URL) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		b.log.Error("failed to read resp body", "error", err, "url", b.cfg.LLMConfig.URL) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	b.log.Debug("llm resp", "body", string(body), "url", b.cfg.LLMConfig.URL) | ||||
| 	return body, nil | ||||
| 	// This line should not be reached because each error path returns in the loop. | ||||
| 	return nil, fmt.Errorf("unknown error in retry loop") | ||||
| } | ||||
|   | ||||
| @@ -86,12 +86,6 @@ func (p *lcpRespParser) ParseBytes(body []byte) (map[string]any, error) { | ||||
| 		p.log.Error("failed to unmarshal", "error", err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	// if len(resp.Choices) == 0 { | ||||
| 	// 	p.log.Error("empty choices", "resp", resp) | ||||
| 	// 	err := errors.New("empty choices in resp") | ||||
| 	// 	return nil, err | ||||
| 	// } | ||||
| 	// text := resp.Choices[0].Message.Content | ||||
| 	text := resp.Content | ||||
| 	li := strings.Index(text, "{") | ||||
| 	ri := strings.LastIndex(text, "}") | ||||
| @@ -123,11 +117,15 @@ func (p *lcpRespParser) MakePayload(prompt string) io.Reader { | ||||
| } | ||||
|  | ||||
// openRouterParser carries a logger and modelIndex, the counter that MakePayload increments to pick a model.
// NOTE(review): rendered diff without +/- markers; the two log lines are the pre- and post-change versions of one field.
| type openRouterParser struct { | ||||
| 	log *slog.Logger | ||||
| 	log        *slog.Logger | ||||
| 	modelIndex uint32 | ||||
| } | ||||
|  | ||||
// NewOpenRouterParser returns an openRouterParser that logs to log, with modelIndex starting at 0.
// NOTE(review): rendered diff without +/- markers; the one-line return is the pre-change version of the multi-line return.
| func NewOpenRouterParser(log *slog.Logger) *openRouterParser { | ||||
| 	return &openRouterParser{log: log} | ||||
| 	return &openRouterParser{ | ||||
| 		log:        log, | ||||
| 		modelIndex: 0, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p *openRouterParser) ParseBytes(body []byte) (map[string]any, error) { | ||||
| @@ -160,17 +158,28 @@ func (p *openRouterParser) ParseBytes(body []byte) (map[string]any, error) { | ||||
| } | ||||
|  | ||||
// MakePayload formats a JSON chat request containing prompt as a single user message and returns it as an io.Reader.
// Each call advances modelIndex and selects the model from a fixed list by modelIndex modulo the list length.
// NOTE(review): rendered diff without +/- markers; the hard-coded model line and the first closing Sprintf/Debug pair are the pre-change versions.
// NOTE(review): prompt is spliced in with %s and no JSON escaping, so quotes, backslashes or newlines in prompt would yield malformed JSON - consider encoding/json.
| func (p *openRouterParser) MakePayload(prompt string) io.Reader { | ||||
| 	// "model": "deepseek/deepseek-chat-v3-0324:free", | ||||
| 	// TODO: make the list of models an option that can be picked on the frontend | ||||
| 	// Models to rotate through | ||||
| 	models := []string{ | ||||
| 		"google/gemini-2.0-flash-exp:free", | ||||
| 		"deepseek/deepseek-chat-v3-0324:free", | ||||
| 		"mistralai/mistral-small-3.2-24b-instruct:free", | ||||
| 		"qwen/qwen3-14b:free", | ||||
| 		"deepseek/deepseek-r1:free", | ||||
| 		"google/gemma-3-27b-it:free", | ||||
| 		"meta-llama/llama-3.3-70b-instruct:free", | ||||
| 	} | ||||
| 	// Advance to the next model; the increment happens before indexing, so the first call uses models[1]. NOTE(review): this is a plain ++ on a uint32 field, not an atomic operation, so it would race if MakePayload is called concurrently. | ||||
| 	p.modelIndex++ | ||||
| 	model := models[int(p.modelIndex)%len(models)] | ||||
| 	strPayload := fmt.Sprintf(`{ | ||||
| 	"model": "google/gemini-2.0-flash-exp:free", | ||||
| 	"model": "%s", | ||||
| 	"messages": [ | ||||
| 		{ | ||||
| 		"role": "user", | ||||
| 		"content": "%s" | ||||
| 		} | ||||
| 	] | ||||
| 	}`, prompt) | ||||
| 	p.log.Debug("made openrouter payload", "payload", strPayload) | ||||
| 	}`, model, prompt) | ||||
| 	p.log.Debug("made openrouter payload", "model", model, "payload", strPayload) | ||||
| 	return strings.NewReader(strPayload) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Grail Finder
					Grail Finder