Fix (race): mutex chatbody
This commit is contained in:
251
models/models.go
251
models/models.go
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type FuncCall struct {
|
||||
@@ -639,3 +640,253 @@ type MultimodalToolResp struct {
|
||||
Type string `json:"type"`
|
||||
Parts []map[string]string `json:"parts"`
|
||||
}
|
||||
|
||||
// SafeChatBody is a thread-safe wrapper around ChatBody using RWMutex.
|
||||
// This allows safe concurrent access to chat state from multiple goroutines.
|
||||
type SafeChatBody struct {
|
||||
mu sync.RWMutex
|
||||
ChatBody
|
||||
}
|
||||
|
||||
// NewSafeChatBody creates a new SafeChatBody from an existing ChatBody.
|
||||
// If cb is nil, creates an empty ChatBody.
|
||||
func NewSafeChatBody(cb *ChatBody) *SafeChatBody {
|
||||
if cb == nil {
|
||||
return &SafeChatBody{
|
||||
ChatBody: ChatBody{
|
||||
Messages: []RoleMsg{},
|
||||
},
|
||||
}
|
||||
}
|
||||
return &SafeChatBody{
|
||||
ChatBody: *cb,
|
||||
}
|
||||
}
|
||||
|
||||
// GetModel returns the model name (thread-safe read).
|
||||
func (s *SafeChatBody) GetModel() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.Model
|
||||
}
|
||||
|
||||
// SetModel sets the model name (thread-safe write).
|
||||
func (s *SafeChatBody) SetModel(model string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Model = model
|
||||
}
|
||||
|
||||
// GetStream returns the stream flag (thread-safe read).
|
||||
func (s *SafeChatBody) GetStream() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.Stream
|
||||
}
|
||||
|
||||
// SetStream sets the stream flag (thread-safe write).
|
||||
func (s *SafeChatBody) SetStream(stream bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Stream = stream
|
||||
}
|
||||
|
||||
// GetMessages returns a copy of all messages (thread-safe read).
|
||||
// Returns a copy to prevent race conditions after the lock is released.
|
||||
func (s *SafeChatBody) GetMessages() []RoleMsg {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
// Return a copy to prevent external modification
|
||||
messagesCopy := make([]RoleMsg, len(s.Messages))
|
||||
copy(messagesCopy, s.Messages)
|
||||
return messagesCopy
|
||||
}
|
||||
|
||||
// SetMessages replaces all messages (thread-safe write).
|
||||
func (s *SafeChatBody) SetMessages(messages []RoleMsg) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Messages = messages
|
||||
}
|
||||
|
||||
// AppendMessage adds a message to the end (thread-safe write).
|
||||
func (s *SafeChatBody) AppendMessage(msg RoleMsg) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Messages = append(s.Messages, msg)
|
||||
}
|
||||
|
||||
// GetMessageAt returns a message at a specific index (thread-safe read).
|
||||
// Returns the message and a boolean indicating if the index was valid.
|
||||
func (s *SafeChatBody) GetMessageAt(index int) (RoleMsg, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if index < 0 || index >= len(s.Messages) {
|
||||
return RoleMsg{}, false
|
||||
}
|
||||
return s.Messages[index], true
|
||||
}
|
||||
|
||||
// SetMessageAt updates a message at a specific index (thread-safe write).
|
||||
// Returns false if index is out of bounds.
|
||||
func (s *SafeChatBody) SetMessageAt(index int, msg RoleMsg) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if index < 0 || index >= len(s.Messages) {
|
||||
return false
|
||||
}
|
||||
s.Messages[index] = msg
|
||||
return true
|
||||
}
|
||||
|
||||
// GetLastMessage returns the last message (thread-safe read).
|
||||
// Returns the message and a boolean indicating if the chat has messages.
|
||||
func (s *SafeChatBody) GetLastMessage() (RoleMsg, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if len(s.Messages) == 0 {
|
||||
return RoleMsg{}, false
|
||||
}
|
||||
return s.Messages[len(s.Messages)-1], true
|
||||
}
|
||||
|
||||
// GetMessageCount returns the number of messages (thread-safe read).
|
||||
func (s *SafeChatBody) GetMessageCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.Messages)
|
||||
}
|
||||
|
||||
// RemoveLastMessage removes the last message (thread-safe write).
|
||||
// Returns false if there are no messages.
|
||||
func (s *SafeChatBody) RemoveLastMessage() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if len(s.Messages) == 0 {
|
||||
return false
|
||||
}
|
||||
s.Messages = s.Messages[:len(s.Messages)-1]
|
||||
return true
|
||||
}
|
||||
|
||||
// TruncateMessages keeps only the first n messages (thread-safe write).
|
||||
func (s *SafeChatBody) TruncateMessages(n int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if n < len(s.Messages) {
|
||||
s.Messages = s.Messages[:n]
|
||||
}
|
||||
}
|
||||
|
||||
// ClearMessages removes all messages (thread-safe write).
|
||||
func (s *SafeChatBody) ClearMessages() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Messages = []RoleMsg{}
|
||||
}
|
||||
|
||||
// Rename renames all occurrences of oldname to newname in messages (thread-safe read-modify-write).
|
||||
func (s *SafeChatBody) Rename(oldname, newname string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for i := range s.Messages {
|
||||
s.Messages[i].Content = strings.ReplaceAll(s.Messages[i].Content, oldname, newname)
|
||||
s.Messages[i].Role = strings.ReplaceAll(s.Messages[i].Role, oldname, newname)
|
||||
}
|
||||
}
|
||||
|
||||
// ListRoles returns all unique roles in messages (thread-safe read).
|
||||
func (s *SafeChatBody) ListRoles() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
namesMap := make(map[string]struct{})
|
||||
for i := range s.Messages {
|
||||
namesMap[s.Messages[i].Role] = struct{}{}
|
||||
}
|
||||
resp := make([]string, len(namesMap))
|
||||
i := 0
|
||||
for k := range namesMap {
|
||||
resp[i] = k
|
||||
i++
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// MakeStopSlice returns stop strings for all roles (thread-safe read).
|
||||
func (s *SafeChatBody) MakeStopSlice() []string {
|
||||
return s.MakeStopSliceExcluding("", s.ListRoles())
|
||||
}
|
||||
|
||||
// MakeStopSliceExcluding returns stop strings excluding a specific role (thread-safe read).
|
||||
func (s *SafeChatBody) MakeStopSliceExcluding(excludeRole string, roleList []string) []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
ss := []string{}
|
||||
for _, role := range roleList {
|
||||
if role == excludeRole {
|
||||
continue
|
||||
}
|
||||
ss = append(ss,
|
||||
role+":\n",
|
||||
role+":",
|
||||
role+": ",
|
||||
role+": ",
|
||||
role+": \n",
|
||||
role+": ",
|
||||
)
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
// UpdateMessageFunc updates a message at index using a provided function.
|
||||
// The function receives the current message and returns the updated message.
|
||||
// This is atomic and thread-safe (read-modify-write under single lock).
|
||||
// Returns false if index is out of bounds.
|
||||
func (s *SafeChatBody) UpdateMessageFunc(index int, updater func(RoleMsg) RoleMsg) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if index < 0 || index >= len(s.Messages) {
|
||||
return false
|
||||
}
|
||||
s.Messages[index] = updater(s.Messages[index])
|
||||
return true
|
||||
}
|
||||
|
||||
// AppendMessageFunc appends a new message created by a provided function.
|
||||
// The function receives the current message count and returns the new message.
|
||||
// This is atomic and thread-safe.
|
||||
func (s *SafeChatBody) AppendMessageFunc(creator func(count int) RoleMsg) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
msg := creator(len(s.Messages))
|
||||
s.Messages = append(s.Messages, msg)
|
||||
}
|
||||
|
||||
// GetMessagesForLLM returns a filtered copy of messages for sending to LLM.
|
||||
// This is thread-safe and returns a copy safe for external modification.
|
||||
func (s *SafeChatBody) GetMessagesForLLM(filterFunc func([]RoleMsg) []RoleMsg) []RoleMsg {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if filterFunc == nil {
|
||||
messagesCopy := make([]RoleMsg, len(s.Messages))
|
||||
copy(messagesCopy, s.Messages)
|
||||
return messagesCopy
|
||||
}
|
||||
return filterFunc(s.Messages)
|
||||
}
|
||||
|
||||
// WithLock executes a function while holding the write lock.
|
||||
// Use this for complex operations that need to be atomic.
|
||||
func (s *SafeChatBody) WithLock(fn func(*ChatBody)) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
fn(&s.ChatBody)
|
||||
}
|
||||
|
||||
// WithRLock executes a function while holding the read lock.
|
||||
// Use this for complex read-only operations.
|
||||
func (s *SafeChatBody) WithRLock(fn func(*ChatBody)) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
fn(&s.ChatBody)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user