Fix (race): mutex chatbody

This commit is contained in:
Grail Finder
2026-03-07 10:46:18 +03:00
parent 014e297ae3
commit a842b00e96
9 changed files with 422 additions and 142 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"strings"
"sync"
)
type FuncCall struct {
@@ -639,3 +640,253 @@ type MultimodalToolResp struct {
Type string `json:"type"`
Parts []map[string]string `json:"parts"`
}
// SafeChatBody is a thread-safe wrapper around ChatBody using RWMutex.
// This allows safe concurrent access to chat state from multiple goroutines.
type SafeChatBody struct {
mu sync.RWMutex
ChatBody
}
// NewSafeChatBody creates a new SafeChatBody from an existing ChatBody.
// If cb is nil, creates an empty ChatBody.
func NewSafeChatBody(cb *ChatBody) *SafeChatBody {
if cb == nil {
return &SafeChatBody{
ChatBody: ChatBody{
Messages: []RoleMsg{},
},
}
}
return &SafeChatBody{
ChatBody: *cb,
}
}
// GetModel returns the model name (thread-safe read).
func (s *SafeChatBody) GetModel() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.Model
}
// SetModel sets the model name (thread-safe write).
func (s *SafeChatBody) SetModel(model string) {
s.mu.Lock()
defer s.mu.Unlock()
s.Model = model
}
// GetStream returns the stream flag (thread-safe read).
func (s *SafeChatBody) GetStream() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.Stream
}
// SetStream sets the stream flag (thread-safe write).
func (s *SafeChatBody) SetStream(stream bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.Stream = stream
}
// GetMessages returns a copy of all messages (thread-safe read).
// Returns a copy to prevent race conditions after the lock is released.
func (s *SafeChatBody) GetMessages() []RoleMsg {
s.mu.RLock()
defer s.mu.RUnlock()
// Return a copy to prevent external modification
messagesCopy := make([]RoleMsg, len(s.Messages))
copy(messagesCopy, s.Messages)
return messagesCopy
}
// SetMessages replaces all messages (thread-safe write).
func (s *SafeChatBody) SetMessages(messages []RoleMsg) {
s.mu.Lock()
defer s.mu.Unlock()
s.Messages = messages
}
// AppendMessage adds a message to the end (thread-safe write).
func (s *SafeChatBody) AppendMessage(msg RoleMsg) {
s.mu.Lock()
defer s.mu.Unlock()
s.Messages = append(s.Messages, msg)
}
// GetMessageAt returns a message at a specific index (thread-safe read).
// Returns the message and a boolean indicating if the index was valid.
func (s *SafeChatBody) GetMessageAt(index int) (RoleMsg, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
if index < 0 || index >= len(s.Messages) {
return RoleMsg{}, false
}
return s.Messages[index], true
}
// SetMessageAt updates a message at a specific index (thread-safe write).
// Returns false if index is out of bounds.
func (s *SafeChatBody) SetMessageAt(index int, msg RoleMsg) bool {
s.mu.Lock()
defer s.mu.Unlock()
if index < 0 || index >= len(s.Messages) {
return false
}
s.Messages[index] = msg
return true
}
// GetLastMessage returns the last message (thread-safe read).
// Returns the message and a boolean indicating if the chat has messages.
func (s *SafeChatBody) GetLastMessage() (RoleMsg, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
if len(s.Messages) == 0 {
return RoleMsg{}, false
}
return s.Messages[len(s.Messages)-1], true
}
// GetMessageCount returns the number of messages (thread-safe read).
func (s *SafeChatBody) GetMessageCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.Messages)
}
// RemoveLastMessage removes the last message (thread-safe write).
// Returns false if there are no messages.
func (s *SafeChatBody) RemoveLastMessage() bool {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.Messages) == 0 {
return false
}
s.Messages = s.Messages[:len(s.Messages)-1]
return true
}
// TruncateMessages keeps only the first n messages (thread-safe write).
func (s *SafeChatBody) TruncateMessages(n int) {
s.mu.Lock()
defer s.mu.Unlock()
if n < len(s.Messages) {
s.Messages = s.Messages[:n]
}
}
// ClearMessages removes all messages (thread-safe write).
func (s *SafeChatBody) ClearMessages() {
s.mu.Lock()
defer s.mu.Unlock()
s.Messages = []RoleMsg{}
}
// Rename renames all occurrences of oldname to newname in messages (thread-safe read-modify-write).
func (s *SafeChatBody) Rename(oldname, newname string) {
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.Messages {
s.Messages[i].Content = strings.ReplaceAll(s.Messages[i].Content, oldname, newname)
s.Messages[i].Role = strings.ReplaceAll(s.Messages[i].Role, oldname, newname)
}
}
// ListRoles returns all unique roles in messages (thread-safe read).
func (s *SafeChatBody) ListRoles() []string {
s.mu.RLock()
defer s.mu.RUnlock()
namesMap := make(map[string]struct{})
for i := range s.Messages {
namesMap[s.Messages[i].Role] = struct{}{}
}
resp := make([]string, len(namesMap))
i := 0
for k := range namesMap {
resp[i] = k
i++
}
return resp
}
// MakeStopSlice returns stop strings for all roles (thread-safe read).
func (s *SafeChatBody) MakeStopSlice() []string {
return s.MakeStopSliceExcluding("", s.ListRoles())
}
// MakeStopSliceExcluding returns stop strings excluding a specific role (thread-safe read).
func (s *SafeChatBody) MakeStopSliceExcluding(excludeRole string, roleList []string) []string {
s.mu.RLock()
defer s.mu.RUnlock()
ss := []string{}
for _, role := range roleList {
if role == excludeRole {
continue
}
ss = append(ss,
role+":\n",
role+":",
role+": ",
role+": ",
role+": \n",
role+": ",
)
}
return ss
}
// UpdateMessageFunc updates a message at index using a provided function.
// The function receives the current message and returns the updated message.
// This is atomic and thread-safe (read-modify-write under single lock).
// Returns false if index is out of bounds.
func (s *SafeChatBody) UpdateMessageFunc(index int, updater func(RoleMsg) RoleMsg) bool {
s.mu.Lock()
defer s.mu.Unlock()
if index < 0 || index >= len(s.Messages) {
return false
}
s.Messages[index] = updater(s.Messages[index])
return true
}
// AppendMessageFunc appends a new message created by a provided function.
// The function receives the current message count and returns the new message.
// This is atomic and thread-safe.
func (s *SafeChatBody) AppendMessageFunc(creator func(count int) RoleMsg) {
s.mu.Lock()
defer s.mu.Unlock()
msg := creator(len(s.Messages))
s.Messages = append(s.Messages, msg)
}
// GetMessagesForLLM returns a filtered copy of messages for sending to LLM.
// This is thread-safe and returns a copy safe for external modification.
func (s *SafeChatBody) GetMessagesForLLM(filterFunc func([]RoleMsg) []RoleMsg) []RoleMsg {
s.mu.RLock()
defer s.mu.RUnlock()
if filterFunc == nil {
messagesCopy := make([]RoleMsg, len(s.Messages))
copy(messagesCopy, s.Messages)
return messagesCopy
}
return filterFunc(s.Messages)
}
// WithLock executes a function while holding the write lock.
// Use this for complex operations that need to be atomic.
func (s *SafeChatBody) WithLock(fn func(*ChatBody)) {
s.mu.Lock()
defer s.mu.Unlock()
fn(&s.ChatBody)
}
// WithRLock executes a function while holding the read lock.
// Use this for complex read-only operations.
func (s *SafeChatBody) WithRLock(fn func(*ChatBody)) {
s.mu.RLock()
defer s.mu.RUnlock()
fn(&s.ChatBody)
}