2 Commits

Author SHA1 Message Date
Grail Finder
8c4d01ab3b Enha: atomic global vars instead of mutexes 2026-03-07 11:26:07 +03:00
Grail Finder
a842b00e96 Fix (race): mutex chatbody 2026-03-07 10:46:18 +03:00
15 changed files with 579 additions and 695 deletions

View File

@@ -1,4 +1,4 @@
.PHONY: setconfig run lint lintall install-linters setup-whisper build-whisper download-whisper-model docker-up docker-down docker-logs noextra-run installdelve checkdelve fetch-onnx install-onnx-deps fetch-kokoro-voices install-espeak .PHONY: setconfig run lint lintall install-linters setup-whisper build-whisper download-whisper-model docker-up docker-down docker-logs noextra-run installdelve checkdelve fetch-onnx install-onnx-deps
run: setconfig run: setconfig
go build -tags extra -o gf-lt && ./gf-lt go build -tags extra -o gf-lt && ./gf-lt
@@ -33,9 +33,6 @@ lintall: lint
fetch-onnx: fetch-onnx:
mkdir -p onnx/embedgemma && curl -o onnx/embedgemma/config.json -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/config.json && curl -o onnx/embedgemma/tokenizer.json -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/tokenizer.json && curl -o onnx/embedgemma/model_q4.onnx -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/onnx/model_q4.onnx && curl -o onnx/embedgemma/model_q4.onnx_data -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/onnx/model_q4.onnx_data?download=true mkdir -p onnx/embedgemma && curl -o onnx/embedgemma/config.json -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/config.json && curl -o onnx/embedgemma/tokenizer.json -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/tokenizer.json && curl -o onnx/embedgemma/model_q4.onnx -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/onnx/model_q4.onnx && curl -o onnx/embedgemma/model_q4.onnx_data -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/onnx/model_q4.onnx_data?download=true
fetch-kokoro-onnx:
mkdir -p onnx/kokoro && curl -o onnx/kokoro/config.json -L https://huggingface.co/onnx-community/Kokoro-82M-v1.0-ONNX/resolve/main/config.json && curl -o onnx/kokoro/tokenizer.json -L https://huggingface.co/onnx-community/Kokoro-82M-v1.0-ONNX/resolve/main/tokenizer.json && curl -o onnx/kokoro/model_quantized.onnx -L https://huggingface.co/onnx-community/Kokoro-82M-v1.0-ONNX/resolve/main/onnx/model_quantized.onnx && curl -o onnx/kokoro/voices.bin -L https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/voices-v1.0.bin
install-onnx-deps: ## Install ONNX Runtime with CUDA support (or CPU fallback) install-onnx-deps: ## Install ONNX Runtime with CUDA support (or CPU fallback)
@echo "=== ONNX Runtime Installer ===" && \ @echo "=== ONNX Runtime Installer ===" && \
echo "" && \ echo "" && \
@@ -197,25 +194,3 @@ docker-logs-whisper: ## View logs from Whisper STT service only
docker-logs-kokoro: ## View logs from Kokoro TTS service only docker-logs-kokoro: ## View logs from Kokoro TTS service only
@echo "Displaying logs from Kokoro TTS service..." @echo "Displaying logs from Kokoro TTS service..."
docker-compose -f batteries/docker-compose.yml logs -f kokoro-tts docker-compose -f batteries/docker-compose.yml logs -f kokoro-tts
# Kokoro ONNX TTS Setup
install-espeak: ## Install espeak-ng for phoneme tokenization
@echo "=== Installing espeak-ng ===" && \
if command -v espeak-ng >/dev/null 2>&1; then \
echo "espeak-ng is already installed:" && \
espeak-ng --version && \
exit 0; \
fi && \
echo "Installing espeak-ng..." && \
sudo apt-get update && \
sudo apt-get install -y espeak-ng espeak && \
echo "espeak-ng installed successfully!" && \
espeak-ng --version
fetch-kokoro-voices: ## Download Kokoro voice files (PyTorch format)
@echo "=== Downloading Kokoro voices ===" && \
mkdir -p onnx/kokoro/voices && \
echo "Downloading af_bella voice..." && \
curl -L -o onnx/kokoro/voices/af_bella.pt https://raw.githubusercontent.com/hexgrad/kokoro/main/kokoro/voices/af_heart.pt && \
echo "Voice file downloaded to onnx/kokoro/voices/" && \
ls -lh onnx/kokoro/voices/

160
bot.go
View File

@@ -22,7 +22,7 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync/atomic"
"time" "time"
) )
@@ -37,7 +37,7 @@ var (
chunkChan = make(chan string, 10) chunkChan = make(chan string, 10)
openAIToolChan = make(chan string, 10) openAIToolChan = make(chan string, 10)
streamDone = make(chan bool, 1) streamDone = make(chan bool, 1)
chatBody *models.ChatBody chatBody *models.SafeChatBody
store storage.FullRepo store storage.FullRepo
defaultFirstMsg = "Hello! What can I do for you?" defaultFirstMsg = "Hello! What can I do for you?"
defaultStarter = []models.RoleMsg{} defaultStarter = []models.RoleMsg{}
@@ -49,7 +49,6 @@ var (
//nolint:unused // TTS_ENABLED conditionally uses this //nolint:unused // TTS_ENABLED conditionally uses this
orator Orator orator Orator
asr STT asr STT
localModelsMu sync.RWMutex
defaultLCPProps = map[string]float32{ defaultLCPProps = map[string]float32{
"temperature": 0.8, "temperature": 0.8,
"dry_multiplier": 0.0, "dry_multiplier": 0.0,
@@ -64,11 +63,17 @@ var (
"google/gemma-3-27b-it:free", "google/gemma-3-27b-it:free",
"meta-llama/llama-3.3-70b-instruct:free", "meta-llama/llama-3.3-70b-instruct:free",
} }
LocalModels = []string{} LocalModels atomic.Value // stores []string
localModelsData *models.LCPModels localModelsData atomic.Value // stores *models.LCPModels
orModelsData *models.ORModels orModelsData atomic.Value // stores *models.ORModels
) )
func init() {
LocalModels.Store([]string{})
localModelsData.Store((*models.LCPModels)(nil))
orModelsData.Store((*models.ORModels)(nil))
}
var thinkBlockRE = regexp.MustCompile(`(?s)<think>.*?</think>`) var thinkBlockRE = regexp.MustCompile(`(?s)<think>.*?</think>`)
// parseKnownToTag extracts known_to list from content using configured tag. // parseKnownToTag extracts known_to list from content using configured tag.
@@ -262,13 +267,13 @@ func warmUpModel() {
return return
} }
// Check if model is already loaded // Check if model is already loaded
loaded, err := isModelLoaded(chatBody.Model) loaded, err := isModelLoaded(chatBody.GetModel())
if err != nil { if err != nil {
logger.Debug("failed to check model status", "model", chatBody.Model, "error", err) logger.Debug("failed to check model status", "model", chatBody.GetModel(), "error", err)
// Continue with warmup attempt anyway // Continue with warmup attempt anyway
} }
if loaded { if loaded {
showToast("model already loaded", "Model "+chatBody.Model+" is already loaded.") showToast("model already loaded", "Model "+chatBody.GetModel()+" is already loaded.")
return return
} }
go func() { go func() {
@@ -277,7 +282,7 @@ func warmUpModel() {
switch { switch {
case strings.HasSuffix(cfg.CurrentAPI, "/completion"): case strings.HasSuffix(cfg.CurrentAPI, "/completion"):
// Old completion endpoint // Old completion endpoint
req := models.NewLCPReq(".", chatBody.Model, nil, map[string]float32{ req := models.NewLCPReq(".", chatBody.GetModel(), nil, map[string]float32{
"temperature": 0.8, "temperature": 0.8,
"dry_multiplier": 0.0, "dry_multiplier": 0.0,
"min_p": 0.05, "min_p": 0.05,
@@ -289,7 +294,7 @@ func warmUpModel() {
// OpenAI-compatible chat endpoint // OpenAI-compatible chat endpoint
req := models.OpenAIReq{ req := models.OpenAIReq{
ChatBody: &models.ChatBody{ ChatBody: &models.ChatBody{
Model: chatBody.Model, Model: chatBody.GetModel(),
Messages: []models.RoleMsg{ Messages: []models.RoleMsg{
{Role: "system", Content: "."}, {Role: "system", Content: "."},
}, },
@@ -313,7 +318,7 @@ func warmUpModel() {
} }
resp.Body.Close() resp.Body.Close()
// Start monitoring for model load completion // Start monitoring for model load completion
monitorModelLoad(chatBody.Model) monitorModelLoad(chatBody.GetModel())
}() }()
} }
@@ -356,7 +361,7 @@ func fetchORModels(free bool) ([]string, error) {
if err := json.NewDecoder(resp.Body).Decode(data); err != nil { if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
return nil, err return nil, err
} }
orModelsData = data orModelsData.Store(data)
freeModels := data.ListModels(free) freeModels := data.ListModels(free)
return freeModels, nil return freeModels, nil
} }
@@ -418,9 +423,7 @@ func fetchLCPModelsWithStatus() (*models.LCPModels, error) {
if err := json.NewDecoder(resp.Body).Decode(data); err != nil { if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
return nil, err return nil, err
} }
localModelsMu.Lock() localModelsData.Store(data)
localModelsData = data
localModelsMu.Unlock()
return data, nil return data, nil
} }
@@ -823,10 +826,10 @@ func chatRound(r *models.ChatRoundReq) error {
} }
go sendMsgToLLM(reader) go sendMsgToLLM(reader)
logger.Debug("looking at vars in chatRound", "msg", r.UserMsg, "regen", r.Regen, "resume", r.Resume) logger.Debug("looking at vars in chatRound", "msg", r.UserMsg, "regen", r.Regen, "resume", r.Resume)
msgIdx := len(chatBody.Messages) msgIdx := chatBody.GetMessageCount()
if !r.Resume { if !r.Resume {
// Add empty message to chatBody immediately so it persists during Alt+T toggle // Add empty message to chatBody immediately so it persists during Alt+T toggle
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{ chatBody.AppendMessage(models.RoleMsg{
Role: botPersona, Content: "", Role: botPersona, Content: "",
}) })
nl := "\n\n" nl := "\n\n"
@@ -838,7 +841,7 @@ func chatRound(r *models.ChatRoundReq) error {
} }
fmt.Fprintf(textView, "%s[-:-:b](%d) %s[-:-:-]\n", nl, msgIdx, roleToIcon(botPersona)) fmt.Fprintf(textView, "%s[-:-:b](%d) %s[-:-:-]\n", nl, msgIdx, roleToIcon(botPersona))
} else { } else {
msgIdx = len(chatBody.Messages) - 1 msgIdx = chatBody.GetMessageCount() - 1
} }
respText := strings.Builder{} respText := strings.Builder{}
toolResp := strings.Builder{} toolResp := strings.Builder{}
@@ -895,7 +898,10 @@ out:
fmt.Fprint(textView, chunk) fmt.Fprint(textView, chunk)
respText.WriteString(chunk) respText.WriteString(chunk)
// Update the message in chatBody.Messages so it persists during Alt+T // Update the message in chatBody.Messages so it persists during Alt+T
chatBody.Messages[msgIdx].Content = respText.String() chatBody.UpdateMessageFunc(msgIdx, func(msg models.RoleMsg) models.RoleMsg {
msg.Content = respText.String()
return msg
})
if scrollToEndEnabled { if scrollToEndEnabled {
textView.ScrollToEnd() textView.ScrollToEnd()
} }
@@ -938,29 +944,32 @@ out:
} }
botRespMode = false botRespMode = false
if r.Resume { if r.Resume {
chatBody.Messages[len(chatBody.Messages)-1].Content += respText.String() chatBody.UpdateMessageFunc(chatBody.GetMessageCount()-1, func(msg models.RoleMsg) models.RoleMsg {
updatedMsg := chatBody.Messages[len(chatBody.Messages)-1] msg.Content += respText.String()
processedMsg := processMessageTag(&updatedMsg) processedMsg := processMessageTag(&msg)
chatBody.Messages[len(chatBody.Messages)-1] = *processedMsg if msgStats != nil && processedMsg.Role != cfg.ToolRole {
if msgStats != nil && chatBody.Messages[len(chatBody.Messages)-1].Role != cfg.ToolRole { processedMsg.Stats = msgStats
chatBody.Messages[len(chatBody.Messages)-1].Stats = msgStats }
} return *processedMsg
})
} else { } else {
chatBody.Messages[msgIdx].Content = respText.String() chatBody.UpdateMessageFunc(msgIdx, func(msg models.RoleMsg) models.RoleMsg {
processedMsg := processMessageTag(&chatBody.Messages[msgIdx]) msg.Content = respText.String()
chatBody.Messages[msgIdx] = *processedMsg processedMsg := processMessageTag(&msg)
if msgStats != nil && chatBody.Messages[msgIdx].Role != cfg.ToolRole { if msgStats != nil && processedMsg.Role != cfg.ToolRole {
chatBody.Messages[msgIdx].Stats = msgStats processedMsg.Stats = msgStats
} }
stopTTSIfNotForUser(&chatBody.Messages[msgIdx]) return *processedMsg
})
stopTTSIfNotForUser(&chatBody.GetMessages()[msgIdx])
} }
cleanChatBody() cleanChatBody()
refreshChatDisplay() refreshChatDisplay()
updateStatusLine() updateStatusLine()
// bot msg is done; // bot msg is done;
// now check it for func call // now check it for func call
// logChat(activeChatName, chatBody.Messages) // logChat(activeChatName, chatBody.GetMessages())
if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil { if err := updateStorageChat(activeChatName, chatBody.GetMessages()); err != nil {
logger.Warn("failed to update storage", "error", err, "name", activeChatName) logger.Warn("failed to update storage", "error", err, "name", activeChatName)
} }
// Strip think blocks before parsing for tool calls // Strip think blocks before parsing for tool calls
@@ -975,8 +984,8 @@ out:
// If so, trigger those characters to respond if that char is not controlled by user // If so, trigger those characters to respond if that char is not controlled by user
// perhaps we should have narrator role to determine which char is next to act // perhaps we should have narrator role to determine which char is next to act
if cfg.AutoTurn { if cfg.AutoTurn {
lastMsg := chatBody.Messages[len(chatBody.Messages)-1] lastMsg, ok := chatBody.GetLastMessage()
if len(lastMsg.KnownTo) > 0 { if ok && len(lastMsg.KnownTo) > 0 {
triggerPrivateMessageResponses(&lastMsg) triggerPrivateMessageResponses(&lastMsg)
} }
} }
@@ -985,13 +994,15 @@ out:
// cleanChatBody removes messages with null or empty content to prevent API issues // cleanChatBody removes messages with null or empty content to prevent API issues
func cleanChatBody() { func cleanChatBody() {
if chatBody == nil || chatBody.Messages == nil { if chatBody == nil || chatBody.GetMessageCount() == 0 {
return return
} }
// Tool request cleaning is now configurable via AutoCleanToolCallsFromCtx (default false) // Tool request cleaning is now configurable via AutoCleanToolCallsFromCtx (default false)
// /completion msg where part meant for user and other part tool call // /completion msg where part meant for user and other part tool call
// chatBody.Messages = cleanToolCalls(chatBody.Messages) // chatBody.Messages = cleanToolCalls(chatBody.Messages)
chatBody.Messages = consolidateAssistantMessages(chatBody.Messages) chatBody.WithLock(func(cb *models.ChatBody) {
cb.Messages = consolidateAssistantMessages(cb.Messages)
})
} }
// convertJSONToMapStringString unmarshals JSON into map[string]interface{} and converts all values to strings. // convertJSONToMapStringString unmarshals JSON into map[string]interface{} and converts all values to strings.
@@ -1091,7 +1102,7 @@ func findCall(msg, toolCall string) bool {
Content: fmt.Sprintf("Error processing tool call: %v. Please check the JSON format and try again.", err), Content: fmt.Sprintf("Error processing tool call: %v. Please check the JSON format and try again.", err),
ToolCallID: lastToolCall.ID, // Use the stored tool call ID ToolCallID: lastToolCall.ID, // Use the stored tool call ID
} }
chatBody.Messages = append(chatBody.Messages, toolResponseMsg) chatBody.AppendMessage(toolResponseMsg)
// Clear the stored tool call ID after using it (no longer needed) // Clear the stored tool call ID after using it (no longer needed)
// Trigger the assistant to continue processing with the error message // Trigger the assistant to continue processing with the error message
crr := &models.ChatRoundReq{ crr := &models.ChatRoundReq{
@@ -1128,7 +1139,7 @@ func findCall(msg, toolCall string) bool {
Role: cfg.ToolRole, Role: cfg.ToolRole,
Content: "Error processing tool call: no valid JSON found. Please check the JSON format.", Content: "Error processing tool call: no valid JSON found. Please check the JSON format.",
} }
chatBody.Messages = append(chatBody.Messages, toolResponseMsg) chatBody.AppendMessage(toolResponseMsg)
crr := &models.ChatRoundReq{ crr := &models.ChatRoundReq{
Role: cfg.AssistantRole, Role: cfg.AssistantRole,
} }
@@ -1145,8 +1156,8 @@ func findCall(msg, toolCall string) bool {
Role: cfg.ToolRole, Role: cfg.ToolRole,
Content: fmt.Sprintf("Error processing tool call: %v. Please check the JSON format and try again.", err), Content: fmt.Sprintf("Error processing tool call: %v. Please check the JSON format and try again.", err),
} }
chatBody.Messages = append(chatBody.Messages, toolResponseMsg) chatBody.AppendMessage(toolResponseMsg)
logger.Debug("findCall: added tool error response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "message_count_after_add", len(chatBody.Messages)) logger.Debug("findCall: added tool error response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "message_count_after_add", chatBody.GetMessageCount())
// Trigger the assistant to continue processing with the error message // Trigger the assistant to continue processing with the error message
// chatRound("", cfg.AssistantRole, tv, false, false) // chatRound("", cfg.AssistantRole, tv, false, false)
crr := &models.ChatRoundReq{ crr := &models.ChatRoundReq{
@@ -1164,17 +1175,23 @@ func findCall(msg, toolCall string) bool {
// we got here => last msg recognized as a tool call (correct or not) // we got here => last msg recognized as a tool call (correct or not)
// Use the tool call ID from streaming response (lastToolCall.ID) // Use the tool call ID from streaming response (lastToolCall.ID)
// Don't generate random ID - the ID should match between assistant message and tool response // Don't generate random ID - the ID should match between assistant message and tool response
lastMsgIdx := len(chatBody.Messages) - 1 lastMsgIdx := chatBody.GetMessageCount() - 1
if lastToolCall.ID != "" { if lastToolCall.ID != "" {
chatBody.Messages[lastMsgIdx].ToolCallID = lastToolCall.ID chatBody.UpdateMessageFunc(lastMsgIdx, func(msg models.RoleMsg) models.RoleMsg {
msg.ToolCallID = lastToolCall.ID
return msg
})
} }
// Store tool call info in the assistant message // Store tool call info in the assistant message
// Convert Args map to JSON string for storage // Convert Args map to JSON string for storage
chatBody.Messages[lastMsgIdx].ToolCall = &models.ToolCall{ chatBody.UpdateMessageFunc(lastMsgIdx, func(msg models.RoleMsg) models.RoleMsg {
ID: lastToolCall.ID, msg.ToolCall = &models.ToolCall{
Name: lastToolCall.Name, ID: lastToolCall.ID,
Args: mapToString(lastToolCall.Args), Name: lastToolCall.Name,
} Args: mapToString(lastToolCall.Args),
}
return msg
})
// call a func // call a func
_, ok := fnMap[fc.Name] _, ok := fnMap[fc.Name]
if !ok { if !ok {
@@ -1185,8 +1202,8 @@ func findCall(msg, toolCall string) bool {
Content: m, Content: m,
ToolCallID: lastToolCall.ID, // Use the stored tool call ID ToolCallID: lastToolCall.ID, // Use the stored tool call ID
} }
chatBody.Messages = append(chatBody.Messages, toolResponseMsg) chatBody.AppendMessage(toolResponseMsg)
logger.Debug("findCall: added tool not implemented response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages)) logger.Debug("findCall: added tool not implemented response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", chatBody.GetMessageCount())
// Clear the stored tool call ID after using it // Clear the stored tool call ID after using it
lastToolCall.ID = "" lastToolCall.ID = ""
// Trigger the assistant to continue processing with the new tool response // Trigger the assistant to continue processing with the new tool response
@@ -1257,9 +1274,9 @@ func findCall(msg, toolCall string) bool {
} }
} }
fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n", fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
"\n\n", len(chatBody.Messages), cfg.ToolRole, toolResponseMsg.GetText()) "\n\n", chatBody.GetMessageCount(), cfg.ToolRole, toolResponseMsg.GetText())
chatBody.Messages = append(chatBody.Messages, toolResponseMsg) chatBody.AppendMessage(toolResponseMsg)
logger.Debug("findCall: added actual tool response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages)) logger.Debug("findCall: added actual tool response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", chatBody.GetMessageCount())
// Clear the stored tool call ID after using it // Clear the stored tool call ID after using it
lastToolCall.ID = "" lastToolCall.ID = ""
// Trigger the assistant to continue processing with the new tool response // Trigger the assistant to continue processing with the new tool response
@@ -1389,7 +1406,7 @@ func charToStart(agentName string, keepSysP bool) bool {
func updateModelLists() { func updateModelLists() {
var err error var err error
if cfg.OpenRouterToken != "" { if cfg.OpenRouterToken != "" {
ORFreeModels, err = fetchORModels(true) _, err := fetchORModels(true)
if err != nil { if err != nil {
logger.Warn("failed to fetch or models", "error", err) logger.Warn("failed to fetch or models", "error", err)
} }
@@ -1399,22 +1416,19 @@ func updateModelLists() {
if err != nil { if err != nil {
logger.Warn("failed to fetch llama.cpp models", "error", err) logger.Warn("failed to fetch llama.cpp models", "error", err)
} }
localModelsMu.Lock() LocalModels.Store(ml)
LocalModels = ml
localModelsMu.Unlock()
for statusLineWidget == nil { for statusLineWidget == nil {
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
} }
// set already loaded model in llama.cpp // set already loaded model in llama.cpp
if strings.Contains(cfg.CurrentAPI, "localhost") || strings.Contains(cfg.CurrentAPI, "127.0.0.1") { if strings.Contains(cfg.CurrentAPI, "localhost") || strings.Contains(cfg.CurrentAPI, "127.0.0.1") {
localModelsMu.Lock() modelList := LocalModels.Load().([]string)
defer localModelsMu.Unlock() for i := range modelList {
for i := range LocalModels { if strings.Contains(modelList[i], models.LoadedMark) {
if strings.Contains(LocalModels[i], models.LoadedMark) { m := strings.TrimPrefix(modelList[i], models.LoadedMark)
m := strings.TrimPrefix(LocalModels[i], models.LoadedMark)
cfg.CurrentModel = m cfg.CurrentModel = m
chatBody.Model = m chatBody.Model = m
cachedModelColor = "green" cachedModelColor.Store("green")
updateStatusLine() updateStatusLine()
updateToolCapabilities() updateToolCapabilities()
app.Draw() app.Draw()
@@ -1425,21 +1439,17 @@ func updateModelLists() {
} }
func refreshLocalModelsIfEmpty() { func refreshLocalModelsIfEmpty() {
localModelsMu.RLock() models := LocalModels.Load().([]string)
if len(LocalModels) > 0 { if len(models) > 0 {
localModelsMu.RUnlock()
return return
} }
localModelsMu.RUnlock()
// try to fetch // try to fetch
models, err := fetchLCPModels() models, err := fetchLCPModels()
if err != nil { if err != nil {
logger.Warn("failed to fetch llama.cpp models", "error", err) logger.Warn("failed to fetch llama.cpp models", "error", err)
return return
} }
localModelsMu.Lock() LocalModels.Store(models)
LocalModels = models
localModelsMu.Unlock()
} }
func summarizeAndStartNewChat() { func summarizeAndStartNewChat() {
@@ -1523,11 +1533,11 @@ func init() {
} }
lastToolCall = &models.FuncCall{} lastToolCall = &models.FuncCall{}
lastChat := loadOldChatOrGetNew() lastChat := loadOldChatOrGetNew()
chatBody = &models.ChatBody{ chatBody = models.NewSafeChatBody(&models.ChatBody{
Model: "modelname", Model: "modelname",
Stream: true, Stream: true,
Messages: lastChat, Messages: lastChat,
} })
choseChunkParser() choseChunkParser()
httpClient = createClient(time.Second * 90) httpClient = createClient(time.Second * 90)
if cfg.TTS_ENABLED { if cfg.TTS_ENABLED {

View File

@@ -61,10 +61,6 @@ type Config struct {
TTS_SPEED float32 `toml:"TTS_SPEED"` TTS_SPEED float32 `toml:"TTS_SPEED"`
TTS_PROVIDER string `toml:"TTS_PROVIDER"` TTS_PROVIDER string `toml:"TTS_PROVIDER"`
TTS_LANGUAGE string `toml:"TTS_LANGUAGE"` TTS_LANGUAGE string `toml:"TTS_LANGUAGE"`
// Kokoro ONNX TTS
KokoroModelPath string `toml:"KokoroModelPath"`
KokoroVoicesPath string `toml:"KokoroVoicesPath"`
KokoroVoice string `toml:"KokoroVoice"`
// STT // STT
STT_TYPE string `toml:"STT_TYPE"` // WHISPER_SERVER, WHISPER_BINARY STT_TYPE string `toml:"STT_TYPE"` // WHISPER_SERVER, WHISPER_BINARY
STT_URL string `toml:"STT_URL"` STT_URL string `toml:"STT_URL"`

View File

@@ -1,421 +0,0 @@
//go:build extra
// +build extra
package extra
import (
"bytes"
"fmt"
"gf-lt/models"
"gf-lt/onnx"
"log/slog"
"os/exec"
"strings"
"sync"
"time"
"github.com/gopxl/beep/v2"
"github.com/gopxl/beep/v2/speaker"
"github.com/gopxl/beep/v2/wav"
"github.com/neurosnap/sentences/english"
"github.com/yalue/onnxruntime_go"
)
// KokoroONNXOrator implements Kokoro TTS using ONNX runtime
type KokoroONNXOrator struct {
logger *slog.Logger
mu sync.Mutex
session *onnxruntime_go.DynamicAdvancedSession
phonemeMap map[string]int
espeakCmd string
voice string
speed float32
styleVector []float32
currentStream *beep.Ctrl
currentDone chan bool
textBuffer strings.Builder
interrupt bool
modelLoaded bool
modelPath string
voicesPath string
}
// Phoneme to token ID mapping from Kokoro tokenizer.json
var kokoroPhonemeMap = map[string]int{
"$": 0, ";": 1, ":": 2, ",": 3, ".": 4, "!": 5, "?": 6, "—": 9, "…": 10, "\"": 11, "(": 12, ")": 13, "“": 14, "”": 15, " ": 16, "̃": 17, "ˢ": 18, "ˤ": 19, "˦": 20, "˨": 21, "ᾝ": 22, "⭧": 23,
"A": 24, "I": 25, "O": 31, "Q": 33, "S": 35, "T": 36, "W": 39, "Y": 41, "ʲ": 42,
"a": 43, "b": 44, "c": 45, "d": 46, "e": 47, "f": 48, "h": 50, "i": 51, "j": 52, "k": 53, "l": 54, "m": 55, "n": 56, "o": 57, "p": 58, "q": 59, "r": 60, "s": 61, "t": 62, "u": 63, "v": 64, "w": 65, "x": 66, "y": 67, "z": 68,
"ɑ": 69, "ɐ": 70, "ɒ": 71, "æ": 72, "β": 75, "ɔ": 76, "ɕ": 77, "ç": 78, "ɖ": 80, "ð": 81, "˔": 82, "ə": 83, "ɚ": 85, "ɛ": 86, "ɜ": 87, "ɟ": 90, "ɡ": 92, "ɥ": 99, "ɨ": 101, "ɪ": 102, "ɝ": 103, "ɯ": 110, "ɰ": 111, "ŋ": 112, "ɳ": 113, "ɲ": 114, "ɴ": 115, "ø": 116, "ɸ": 118, "θ": 119, "œ": 120, "ɹ": 123, "ɾ": 125, "ɺ": 126, "ʁ": 128, "ɽ": 129, "ʂ": 130, "ʃ": 131, "ʈ": 132, "˧": 133, "ʊ": 135, "ʋ": 136, "ʌ": 138, "ɢ": 139, "ɣ": 140, "χ": 142, "ʎ": 143, "ʒ": 147, "ʔ": 148,
"ˈ": 156, "ˌ": 157, "ː": 158, "̰": 162, "̊": 164, "↕": 169, "→": 171, "↗": 172, "↘": 173, "ᶻ": 177,
}
func (o *KokoroONNXOrator) ensureInitialized(modelPath string) error {
if o.modelLoaded {
return nil
}
o.mu.Lock()
defer o.mu.Unlock()
if o.modelLoaded {
return nil
}
if modelPath == "" {
o.logger.Error("modelPath is empty, cannot load ONNX model")
return fmt.Errorf("modelPath is empty, set KokoroModelPath in config")
}
// Initialize ONNX runtime (shared with embedder)
if err := onnx.Init(); err != nil {
o.logger.Error("ONNX init failed", "error", err)
return fmt.Errorf("ONNX init failed: %w", err)
}
if onnx.HasCUDASupport() {
o.logger.Info("ONNX using CUDA")
} else {
o.logger.Info("ONNX using CPU fallback")
}
if o.phonemeMap == nil {
o.phonemeMap = kokoroPhonemeMap
}
if o.espeakCmd == "" {
o.espeakCmd = "espeak-ng"
if _, err := exec.LookPath(o.espeakCmd); err != nil {
o.espeakCmd = "espeak"
if _, err := exec.LookPath(o.espeakCmd); err != nil {
return fmt.Errorf("espeak-ng or espeak not found. Install with: sudo apt-get install espeak-ng")
}
}
}
o.logger.Info("using espeak command", "cmd", o.espeakCmd)
// Load voice embedding if not already loaded
if o.styleVector == nil {
voiceName := o.voice
if voiceName == "" {
voiceName = "af_bella"
}
if o.voicesPath != "" {
styleVec, err := onnx.LoadVoice(o.voicesPath, voiceName)
if err != nil {
o.logger.Warn("failed to load voice, using zeros", "error", err, "voice", voiceName)
o.styleVector = make([]float32, 256)
} else {
// Shape is (510, 1, 256), we want the last 256 values (or first? let's use mean or just pick one)
// Actually, let's average across all 510 to get a single 256-dim vector
if len(styleVec) != 510*256 {
o.logger.Error("voice embedding has unexpected size", "len", len(styleVec))
err = fmt.Errorf("voice embedding has unexpected size", "len", len(styleVec))
return err
}
o.styleVector = make([]float32, 256)
for i := 0; i < 256; i++ {
var sum float32
for j := 0; j < 510; j++ {
sum += styleVec[j*256+i]
}
o.styleVector[i] = sum / 510.0
}
o.logger.Info("loaded voice embedding", "voice", voiceName)
}
} else {
o.logger.Warn("no voices path configured, using zeros for style")
o.styleVector = make([]float32, 256)
}
}
opts, err := onnx.NewSessionOptions()
if err != nil {
return fmt.Errorf("failed to create session options: %w", err)
}
defer func() { _ = opts.Destroy() }()
if onnx.HasCUDASupport() {
o.logger.Info("session options created with CUDA")
} else {
o.logger.Info("session options created with CPU")
}
session, err := onnxruntime_go.NewDynamicAdvancedSession(
modelPath,
[]string{"input_ids", "style", "speed"},
[]string{"waveform"},
opts,
)
if err != nil {
o.logger.Error("failed to create ONNX session", "error", err)
return fmt.Errorf("failed to create ONNX session: %w", err)
}
o.session = session
o.modelLoaded = true
o.logger.Info("Kokoro ONNX model loaded successfully", "model", modelPath)
return nil
}
func (o *KokoroONNXOrator) textToPhonemes(text string) (string, error) {
cmd := exec.Command(o.espeakCmd, "-x", "-q", text)
output, err := cmd.Output()
if err != nil {
o.logger.Error("espeak failed", "error", err, "cmd", o.espeakCmd, "text", text)
return "", fmt.Errorf("espeak failed: %w", err)
}
phonemeStr := strings.TrimSpace(string(output))
return phonemeStr, nil
}
func (o *KokoroONNXOrator) phonemesToTokens(phonemeStr string) ([]int, error) {
if phonemeStr == "" {
o.logger.Error("empty phoneme string")
return nil, fmt.Errorf("empty phoneme string")
}
// Iterate over each character in the phoneme string
tokens := make([]int, 0)
for _, ch := range phonemeStr {
chStr := string(ch)
if tokenID, ok := o.phonemeMap[chStr]; ok {
tokens = append(tokens, tokenID)
}
}
if len(tokens) == 0 {
o.logger.Error("no phonemes mapped to tokens", "phonemeStr", phonemeStr)
return nil, fmt.Errorf("no valid phonemes mapped to tokens")
}
return tokens, nil
}
func (o *KokoroONNXOrator) generateAudio(text string) ([]float32, error) {
if err := o.ensureInitialized(o.modelPath); err != nil {
o.logger.Error("ensureInitialized failed", "error", err)
return nil, err
}
phonemeStr, err := o.textToPhonemes(text)
if err != nil {
o.logger.Error("phoneme conversion failed", "error", err)
return nil, fmt.Errorf("phoneme conversion failed: %w", err)
}
tokens, err := o.phonemesToTokens(phonemeStr)
if err != nil {
o.logger.Error("token conversion failed", "error", err)
return nil, fmt.Errorf("token conversion failed: %w", err)
}
if len(tokens) > 510 {
return nil, fmt.Errorf("text too long: %d tokens (max 510)", len(tokens))
}
tokens = append([]int{0}, tokens...)
tokens = append(tokens, 0)
inputIDs := make([]int64, len(tokens))
for i, t := range tokens {
inputIDs[i] = int64(t)
}
inputTensor, err := onnxruntime_go.NewTensor[int64](
onnxruntime_go.NewShape(1, int64(len(inputIDs))),
inputIDs,
)
if err != nil {
o.logger.Error("failed to create input tensor", "error", err)
return nil, fmt.Errorf("failed to create input tensor: %w", err)
}
defer func() { _ = inputTensor.Destroy() }()
styleTensor, err := onnxruntime_go.NewTensor[float32](
onnxruntime_go.NewShape(1, 256),
o.styleVector,
)
if err != nil {
o.logger.Error("failed to create style tensor", "error", err)
return nil, fmt.Errorf("failed to create style tensor: %w", err)
}
defer func() { _ = styleTensor.Destroy() }()
speedTensor, err := onnxruntime_go.NewTensor[float32](
onnxruntime_go.NewShape(1),
[]float32{o.speed},
)
if err != nil {
o.logger.Error("failed to create speed tensor", "error", err)
return nil, fmt.Errorf("failed to create speed tensor: %w", err)
}
defer func() { _ = speedTensor.Destroy() }()
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
onnxruntime_go.NewShape(1, 512),
)
if err != nil {
o.logger.Error("failed to create output tensor", "error", err)
return nil, fmt.Errorf("failed to create output tensor: %w", err)
}
defer func() { _ = outputTensor.Destroy() }()
err = o.session.Run(
[]onnxruntime_go.Value{inputTensor, styleTensor, speedTensor},
[]onnxruntime_go.Value{outputTensor},
)
if err != nil {
o.logger.Error("ONNX inference failed", "error", err)
return nil, fmt.Errorf("ONNX inference failed: %w", err)
}
audioData := outputTensor.GetData()
if len(audioData) == 0 {
o.logger.Error("empty audio output from ONNX")
return nil, fmt.Errorf("empty audio output")
}
audio := make([]float32, len(audioData))
copy(audio, audioData)
return audio, nil
}
func (o *KokoroONNXOrator) Speak(text string) error {
audio, err := o.generateAudio(text)
if err != nil {
o.logger.Error("audio generation failed", "error", err)
return fmt.Errorf("audio generation failed: %w", err)
}
// Create streamer for encoding
encodeStreamer := beep.StreamerFunc(func(samples [][2]float64) (n int, ok bool) {
for i := range samples {
if i >= len(audio) {
return i, false
}
samples[i][0] = float64(audio[i])
samples[i][1] = float64(audio[i])
}
return len(audio), true
})
buf := &seekableBuffer{new(bytes.Buffer)}
err = wav.Encode(buf, encodeStreamer, beep.Format{
SampleRate: 24000,
NumChannels: 1,
Precision: 2,
})
if err != nil {
o.logger.Error("wav encoding failed", "error", err)
return fmt.Errorf("wav encoding failed: %w", err)
}
decodedStreamer, format, err := wav.Decode(bytes.NewReader(buf.Bytes()))
if err != nil {
o.logger.Error("wav decode failed", "error", err)
return fmt.Errorf("wav decode failed: %w", err)
}
defer decodedStreamer.Close()
if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil {
o.logger.Error("speaker init failed", "error", err)
return fmt.Errorf("speaker init failed: %w", err)
}
o.logger.Info("playing audio", "sampleRate", format.SampleRate, "channels", format.NumChannels)
done := make(chan bool)
o.mu.Lock()
o.currentDone = done
o.currentStream = &beep.Ctrl{Streamer: beep.Seq(decodedStreamer, beep.Callback(func() {
o.mu.Lock()
close(done)
o.currentStream = nil
o.currentDone = nil
o.mu.Unlock()
})), Paused: false}
o.mu.Unlock()
speaker.Play(o.currentStream)
<-done
return nil
}
func (o *KokoroONNXOrator) Stop() {
speaker.Lock()
defer speaker.Unlock()
o.mu.Lock()
defer o.mu.Unlock()
if o.currentStream != nil {
o.currentStream.Streamer = nil
}
}
func (o *KokoroONNXOrator) GetLogger() *slog.Logger {
return o.logger
}
func (o *KokoroONNXOrator) stoproutine() {
for {
<-TTSDoneChan
o.Stop()
for len(TTSTextChan) > 0 {
<-TTSTextChan
}
o.mu.Lock()
o.textBuffer.Reset()
if o.currentDone != nil {
select {
case o.currentDone <- true:
default:
}
}
o.interrupt = true
o.mu.Unlock()
}
}
func (o *KokoroONNXOrator) readroutine() {
tokenizer, _ := english.NewSentenceTokenizer(nil)
for {
select {
case chunk := <-TTSTextChan:
o.mu.Lock()
o.interrupt = false
_, err := o.textBuffer.WriteString(chunk)
if err != nil {
o.logger.Warn("failed to write to buffer", "error", err)
o.mu.Unlock()
continue
}
text := o.textBuffer.String()
sentences := tokenizer.Tokenize(text)
if len(sentences) <= 1 {
o.mu.Unlock()
continue
}
completeSentences := sentences[:len(sentences)-1]
remaining := sentences[len(sentences)-1].Text
o.textBuffer.Reset()
o.textBuffer.WriteString(remaining)
o.mu.Unlock()
for _, sentence := range completeSentences {
o.mu.Lock()
interrupted := o.interrupt
o.mu.Unlock()
if interrupted {
return
}
cleanedText := models.CleanText(sentence.Text)
if cleanedText == "" {
continue
}
o.logger.Info("KokoroONNX speak", "text", cleanedText)
if err := o.Speak(cleanedText); err != nil {
o.logger.Error("KokoroONNX tts failed", "text", cleanedText, "error", err)
}
}
case <-TTSFlushChan:
if len(TTSTextChan) > 0 {
for chunk := range TTSTextChan {
o.mu.Lock()
_, err := o.textBuffer.WriteString(chunk)
o.mu.Unlock()
if err != nil {
continue
}
if len(TTSTextChan) == 0 {
break
}
}
}
o.mu.Lock()
remaining := o.textBuffer.String()
remaining = models.CleanText(remaining)
o.textBuffer.Reset()
o.mu.Unlock()
if remaining == "" {
continue
}
sentencesRem := tokenizer.Tokenize(remaining)
for _, rs := range sentencesRem {
o.mu.Lock()
interrupt := o.interrupt
o.mu.Unlock()
if interrupt {
break
}
if err := o.Speak(rs.Text); err != nil {
o.logger.Error("tts failed", "text", rs.Text, "error", err)
}
}
}
}
}

View File

@@ -32,14 +32,6 @@ var (
// endsWithPunctuation = regexp.MustCompile(`[;.!?]$`) // endsWithPunctuation = regexp.MustCompile(`[;.!?]$`)
) )
type seekableBuffer struct {
*bytes.Buffer
}
func (s *seekableBuffer) Seek(offset int64, whence int) (int64, error) {
return 0, nil
}
type Orator interface { type Orator interface {
Speak(text string) error Speak(text string) error
Stop() Stop()
@@ -202,18 +194,6 @@ func NewOrator(log *slog.Logger, cfg *config.Config) Orator {
go orator.readroutine() go orator.readroutine()
go orator.stoproutine() go orator.stoproutine()
return orator return orator
case "kokoro_onnx":
log.Info("Initializing Kokoro ONNX TTS", "modelPath", cfg.KokoroModelPath, "voicesPath", cfg.KokoroVoicesPath, "voice", cfg.KokoroVoice, "speed", cfg.TTS_SPEED)
orator := &KokoroONNXOrator{
logger: log,
modelPath: cfg.KokoroModelPath,
voicesPath: cfg.KokoroVoicesPath,
speed: cfg.TTS_SPEED,
voice: cfg.KokoroVoice,
}
go orator.readroutine()
go orator.stoproutine()
return orator
default: default:
language := cfg.TTS_LANGUAGE language := cfg.TTS_LANGUAGE
if language == "" { if language == "" {

View File

@@ -16,11 +16,17 @@ import (
"time" "time"
"unicode" "unicode"
"sync/atomic"
"github.com/rivo/tview" "github.com/rivo/tview"
) )
// Cached model color - updated by background goroutine // Cached model color - updated by background goroutine
var cachedModelColor string = "orange" var cachedModelColor atomic.Value // stores string
func init() {
cachedModelColor.Store("orange")
}
// startModelColorUpdater starts a background goroutine that periodically updates // startModelColorUpdater starts a background goroutine that periodically updates
// the cached model color. Only runs HTTP requests for local llama.cpp APIs. // the cached model color. Only runs HTTP requests for local llama.cpp APIs.
@@ -39,20 +45,20 @@ func startModelColorUpdater() {
// updateCachedModelColor updates the global cachedModelColor variable // updateCachedModelColor updates the global cachedModelColor variable
func updateCachedModelColor() { func updateCachedModelColor() {
if !isLocalLlamacpp() { if !isLocalLlamacpp() {
cachedModelColor = "orange" cachedModelColor.Store("orange")
return return
} }
// Check if model is loaded // Check if model is loaded
loaded, err := isModelLoaded(chatBody.Model) loaded, err := isModelLoaded(chatBody.GetModel())
if err != nil { if err != nil {
// On error, assume not loaded (red) // On error, assume not loaded (red)
cachedModelColor = "red" cachedModelColor.Store("red")
return return
} }
if loaded { if loaded {
cachedModelColor = "green" cachedModelColor.Store("green")
} else { } else {
cachedModelColor = "red" cachedModelColor.Store("red")
} }
} }
@@ -103,7 +109,7 @@ func refreshChatDisplay() {
viewingAs = cfg.WriteNextMsgAs viewingAs = cfg.WriteNextMsgAs
} }
// Filter messages for this character // Filter messages for this character
filteredMessages := filterMessagesForCharacter(chatBody.Messages, viewingAs) filteredMessages := filterMessagesForCharacter(chatBody.GetMessages(), viewingAs)
displayText := chatToText(filteredMessages, cfg.ShowSys) displayText := chatToText(filteredMessages, cfg.ShowSys)
textView.SetText(displayText) textView.SetText(displayText)
colorText() colorText()
@@ -217,8 +223,8 @@ func startNewChat(keepSysP bool) {
logger.Warn("no such sys msg", "name", cfg.AssistantRole) logger.Warn("no such sys msg", "name", cfg.AssistantRole)
} }
// set chat body // set chat body
chatBody.Messages = chatBody.Messages[:2] chatBody.TruncateMessages(2)
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
newChat := &models.Chat{ newChat := &models.Chat{
ID: id + 1, ID: id + 1,
Name: fmt.Sprintf("%d_%s", id+1, cfg.AssistantRole), Name: fmt.Sprintf("%d_%s", id+1, cfg.AssistantRole),
@@ -335,7 +341,7 @@ func isLocalLlamacpp() bool {
// The cached value is updated by a background goroutine every 5 seconds. // The cached value is updated by a background goroutine every 5 seconds.
// For non-local models, returns orange. For local llama.cpp models, returns green if loaded, red if not. // For non-local models, returns orange. For local llama.cpp models, returns green if loaded, red if not.
func getModelColor() string { func getModelColor() string {
return cachedModelColor return cachedModelColor.Load().(string)
} }
func makeStatusLine() string { func makeStatusLine() string {
@@ -370,7 +376,7 @@ func makeStatusLine() string {
// Get model color based on load status for local llama.cpp models // Get model color based on load status for local llama.cpp models
modelColor := getModelColor() modelColor := getModelColor()
statusLine := fmt.Sprintf(statusLineTempl, activeChatName, statusLine := fmt.Sprintf(statusLineTempl, activeChatName,
boolColors[cfg.ToolUse], modelColor, chatBody.Model, boolColors[cfg.SkipLLMResp], boolColors[cfg.ToolUse], modelColor, chatBody.GetModel(), boolColors[cfg.SkipLLMResp],
cfg.CurrentAPI, persona, botPersona) cfg.CurrentAPI, persona, botPersona)
if cfg.STT_ENABLED { if cfg.STT_ENABLED {
recordingS := fmt.Sprintf(" | [%s:-:b]voice recording[-:-:-] (ctrl+r)", recordingS := fmt.Sprintf(" | [%s:-:b]voice recording[-:-:-] (ctrl+r)",
@@ -396,11 +402,11 @@ func makeStatusLine() string {
} }
func getContextTokens() int { func getContextTokens() int {
if chatBody == nil || chatBody.Messages == nil { if chatBody == nil {
return 0 return 0
} }
total := 0 total := 0
messages := chatBody.Messages messages := chatBody.GetMessages()
for i := range messages { for i := range messages {
msg := &messages[i] msg := &messages[i]
if msg.Stats != nil && msg.Stats.Tokens > 0 { if msg.Stats != nil && msg.Stats.Tokens > 0 {
@@ -415,46 +421,54 @@ func getContextTokens() int {
const deepseekContext = 128000 const deepseekContext = 128000
func getMaxContextTokens() int { func getMaxContextTokens() int {
if chatBody == nil || chatBody.Model == "" { if chatBody == nil || chatBody.GetModel() == "" {
return 0 return 0
} }
modelName := chatBody.Model modelName := chatBody.GetModel()
switch { switch {
case strings.Contains(cfg.CurrentAPI, "openrouter"): case strings.Contains(cfg.CurrentAPI, "openrouter"):
if orModelsData != nil { ord := orModelsData.Load()
for i := range orModelsData.Data { if ord != nil {
m := &orModelsData.Data[i] data := ord.(*models.ORModels)
if m.ID == modelName { if data != nil {
return m.ContextLength for i := range data.Data {
m := &data.Data[i]
if m.ID == modelName {
return m.ContextLength
}
} }
} }
} }
case strings.Contains(cfg.CurrentAPI, "deepseek"): case strings.Contains(cfg.CurrentAPI, "deepseek"):
return deepseekContext return deepseekContext
default: default:
if localModelsData != nil { lmd := localModelsData.Load()
for i := range localModelsData.Data { if lmd != nil {
m := &localModelsData.Data[i] data := lmd.(*models.LCPModels)
if m.ID == modelName { if data != nil {
for _, arg := range m.Status.Args { for i := range data.Data {
if strings.HasPrefix(arg, "--ctx-size") { m := &data.Data[i]
if strings.Contains(arg, "=") { if m.ID == modelName {
val := strings.Split(arg, "=")[1] for _, arg := range m.Status.Args {
if n, err := strconv.Atoi(val); err == nil { if strings.HasPrefix(arg, "--ctx-size") {
return n if strings.Contains(arg, "=") {
} val := strings.Split(arg, "=")[1]
} else { if n, err := strconv.Atoi(val); err == nil {
idx := -1
for j, a := range m.Status.Args {
if a == "--ctx-size" && j+1 < len(m.Status.Args) {
idx = j + 1
break
}
}
if idx != -1 {
if n, err := strconv.Atoi(m.Status.Args[idx]); err == nil {
return n return n
} }
} else {
idx := -1
for j, a := range m.Status.Args {
if a == "--ctx-size" && j+1 < len(m.Status.Args) {
idx = j + 1
break
}
}
if idx != -1 {
if n, err := strconv.Atoi(m.Status.Args[idx]); err == nil {
return n
}
}
} }
} }
} }
@@ -490,7 +504,7 @@ func listChatRoles() []string {
func deepseekModelValidator() error { func deepseekModelValidator() error {
if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI { if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" { if chatBody.GetModel() != "deepseek-chat" && chatBody.GetModel() != "deepseek-reasoner" {
showToast("bad request", "wrong deepseek model name") showToast("bad request", "wrong deepseek model name")
return nil return nil
} }
@@ -567,13 +581,13 @@ func executeCommandAndDisplay(cmdText string) {
outputContent := workingDir outputContent := workingDir
// Add the command being executed to the chat // Add the command being executed to the chat
fmt.Fprintf(textView, "\n[-:-:b](%d) <%s>: [-:-:-]\n$ %s\n", fmt.Fprintf(textView, "\n[-:-:b](%d) <%s>: [-:-:-]\n$ %s\n",
len(chatBody.Messages), cfg.ToolRole, cmdText) chatBody.GetMessageCount(), cfg.ToolRole, cmdText)
fmt.Fprintf(textView, "%s\n", outputContent) fmt.Fprintf(textView, "%s\n", outputContent)
combinedMsg := models.RoleMsg{ combinedMsg := models.RoleMsg{
Role: cfg.ToolRole, Role: cfg.ToolRole,
Content: "$ " + cmdText + "\n\n" + outputContent, Content: "$ " + cmdText + "\n\n" + outputContent,
} }
chatBody.Messages = append(chatBody.Messages, combinedMsg) chatBody.AppendMessage(combinedMsg)
if scrollToEndEnabled { if scrollToEndEnabled {
textView.ScrollToEnd() textView.ScrollToEnd()
} }
@@ -582,13 +596,13 @@ func executeCommandAndDisplay(cmdText string) {
} else { } else {
outputContent := "cd: " + newDir + ": No such file or directory" outputContent := "cd: " + newDir + ": No such file or directory"
fmt.Fprintf(textView, "\n[-:-:b](%d) <%s>: [-:-:-]\n$ %s\n", fmt.Fprintf(textView, "\n[-:-:b](%d) <%s>: [-:-:-]\n$ %s\n",
len(chatBody.Messages), cfg.ToolRole, cmdText) chatBody.GetMessageCount(), cfg.ToolRole, cmdText)
fmt.Fprintf(textView, "[red]%s[-:-:-]\n", outputContent) fmt.Fprintf(textView, "[red]%s[-:-:-]\n", outputContent)
combinedMsg := models.RoleMsg{ combinedMsg := models.RoleMsg{
Role: cfg.ToolRole, Role: cfg.ToolRole,
Content: "$ " + cmdText + "\n\n" + outputContent, Content: "$ " + cmdText + "\n\n" + outputContent,
} }
chatBody.Messages = append(chatBody.Messages, combinedMsg) chatBody.AppendMessage(combinedMsg)
if scrollToEndEnabled { if scrollToEndEnabled {
textView.ScrollToEnd() textView.ScrollToEnd()
} }
@@ -604,7 +618,7 @@ func executeCommandAndDisplay(cmdText string) {
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
// Add the command being executed to the chat // Add the command being executed to the chat
fmt.Fprintf(textView, "\n[-:-:b](%d) <%s>: [-:-:-]\n$ %s\n", fmt.Fprintf(textView, "\n[-:-:b](%d) <%s>: [-:-:-]\n$ %s\n",
len(chatBody.Messages), cfg.ToolRole, cmdText) chatBody.GetMessageCount(), cfg.ToolRole, cmdText)
var outputContent string var outputContent string
if err != nil { if err != nil {
// Include both output and error // Include both output and error
@@ -635,7 +649,7 @@ func executeCommandAndDisplay(cmdText string) {
Role: cfg.ToolRole, Role: cfg.ToolRole,
Content: combinedContent, Content: combinedContent,
} }
chatBody.Messages = append(chatBody.Messages, combinedMsg) chatBody.AppendMessage(combinedMsg)
// Scroll to end and update colors // Scroll to end and update colors
if scrollToEndEnabled { if scrollToEndEnabled {
textView.ScrollToEnd() textView.ScrollToEnd()
@@ -665,7 +679,7 @@ func performSearch(term string) {
searchResultLengths = nil searchResultLengths = nil
originalTextForSearch = "" originalTextForSearch = ""
// Re-render text without highlights // Re-render text without highlights
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
colorText() colorText()
return return
} }

55
llm.go
View File

@@ -13,8 +13,9 @@ var lastImg string // for ctrl+j
// containsToolSysMsg checks if the toolSysMsg already exists in the chat body // containsToolSysMsg checks if the toolSysMsg already exists in the chat body
func containsToolSysMsg() bool { func containsToolSysMsg() bool {
for i := range chatBody.Messages { messages := chatBody.GetMessages()
if chatBody.Messages[i].Role == cfg.ToolRole && chatBody.Messages[i].Content == toolSysMsg { for i := range messages {
if messages[i].Role == cfg.ToolRole && messages[i].Content == toolSysMsg {
return true return true
} }
} }
@@ -135,13 +136,13 @@ func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, erro
newMsg = models.RoleMsg{Role: role, Content: msg} newMsg = models.RoleMsg{Role: role, Content: msg}
} }
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.AppendMessage(newMsg)
} }
// sending description of the tools and how to use them // sending description of the tools and how to use them
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() { if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) chatBody.AppendMessage(models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
} }
filteredMessages, botPersona := filterMessagesForCurrentCharacter(chatBody.Messages) filteredMessages, botPersona := filterMessagesForCurrentCharacter(chatBody.GetMessages())
// Build prompt and extract images inline as we process each message // Build prompt and extract images inline as we process each message
messages := make([]string, len(filteredMessages)) messages := make([]string, len(filteredMessages))
for i := range filteredMessages { for i := range filteredMessages {
@@ -183,7 +184,7 @@ func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, erro
} }
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt, "multimodal_data_count", len(multimodalData)) "msg", msg, "resume", resume, "prompt", prompt, "multimodal_data_count", len(multimodalData))
payload := models.NewLCPReq(prompt, chatBody.Model, multimodalData, payload := models.NewLCPReq(prompt, chatBody.GetModel(), multimodalData,
defaultLCPProps, chatBody.MakeStopSliceExcluding("", listChatRoles())) defaultLCPProps, chatBody.MakeStopSliceExcluding("", listChatRoles()))
data, err := json.Marshal(payload) data, err := json.Marshal(payload)
if err != nil { if err != nil {
@@ -289,17 +290,17 @@ func (op LCPChat) FormMsg(msg, role string, resume bool) (io.Reader, error) {
newMsg = models.NewRoleMsg(role, msg) newMsg = models.NewRoleMsg(role, msg)
} }
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.AppendMessage(newMsg)
logger.Debug("LCPChat FormMsg: added message to chatBody", "role", newMsg.Role, logger.Debug("LCPChat FormMsg: added message to chatBody", "role", newMsg.Role,
"content_len", len(newMsg.Content), "message_count_after_add", len(chatBody.Messages)) "content_len", len(newMsg.Content), "message_count_after_add", chatBody.GetMessageCount())
} }
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages) filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.GetMessages())
// openai /v1/chat does not support custom roles; needs to be user, assistant, system // openai /v1/chat does not support custom roles; needs to be user, assistant, system
// Add persona suffix to the last user message to indicate who the assistant should reply as // Add persona suffix to the last user message to indicate who the assistant should reply as
bodyCopy := &models.ChatBody{ bodyCopy := &models.ChatBody{
Messages: make([]models.RoleMsg, len(filteredMessages)), Messages: make([]models.RoleMsg, len(filteredMessages)),
Model: chatBody.Model, Model: chatBody.GetModel(),
Stream: chatBody.Stream, Stream: chatBody.GetStream(),
} }
for i := range filteredMessages { for i := range filteredMessages {
strippedMsg := *stripThinkingFromMsg(&filteredMessages[i]) strippedMsg := *stripThinkingFromMsg(&filteredMessages[i])
@@ -375,13 +376,13 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
if msg != "" { // otherwise let the bot to continue if msg != "" { // otherwise let the bot to continue
newMsg := models.RoleMsg{Role: role, Content: msg} newMsg := models.RoleMsg{Role: role, Content: msg}
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.AppendMessage(newMsg)
} }
// sending description of the tools and how to use them // sending description of the tools and how to use them
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() { if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) chatBody.AppendMessage(models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
} }
filteredMessages, botPersona := filterMessagesForCurrentCharacter(chatBody.Messages) filteredMessages, botPersona := filterMessagesForCurrentCharacter(chatBody.GetMessages())
messages := make([]string, len(filteredMessages)) messages := make([]string, len(filteredMessages))
for i := range filteredMessages { for i := range filteredMessages {
messages[i] = stripThinkingFromMsg(&filteredMessages[i]).ToPrompt() messages[i] = stripThinkingFromMsg(&filteredMessages[i]).ToPrompt()
@@ -394,7 +395,7 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
} }
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt) "msg", msg, "resume", resume, "prompt", prompt)
payload := models.NewDSCompletionReq(prompt, chatBody.Model, payload := models.NewDSCompletionReq(prompt, chatBody.GetModel(),
defaultLCPProps["temp"], defaultLCPProps["temp"],
chatBody.MakeStopSliceExcluding("", listChatRoles())) chatBody.MakeStopSliceExcluding("", listChatRoles()))
data, err := json.Marshal(payload) data, err := json.Marshal(payload)
@@ -448,15 +449,15 @@ func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
if msg != "" { // otherwise let the bot continue if msg != "" { // otherwise let the bot continue
newMsg := models.RoleMsg{Role: role, Content: msg} newMsg := models.RoleMsg{Role: role, Content: msg}
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.AppendMessage(newMsg)
} }
// Create copy of chat body with standardized user role // Create copy of chat body with standardized user role
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages) filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.GetMessages())
// Add persona suffix to the last user message to indicate who the assistant should reply as // Add persona suffix to the last user message to indicate who the assistant should reply as
bodyCopy := &models.ChatBody{ bodyCopy := &models.ChatBody{
Messages: make([]models.RoleMsg, len(filteredMessages)), Messages: make([]models.RoleMsg, len(filteredMessages)),
Model: chatBody.Model, Model: chatBody.GetModel(),
Stream: chatBody.Stream, Stream: chatBody.GetStream(),
} }
for i := range filteredMessages { for i := range filteredMessages {
strippedMsg := *stripThinkingFromMsg(&filteredMessages[i]) strippedMsg := *stripThinkingFromMsg(&filteredMessages[i])
@@ -527,13 +528,13 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader
if msg != "" { // otherwise let the bot to continue if msg != "" { // otherwise let the bot to continue
newMsg := models.RoleMsg{Role: role, Content: msg} newMsg := models.RoleMsg{Role: role, Content: msg}
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.AppendMessage(newMsg)
} }
// sending description of the tools and how to use them // sending description of the tools and how to use them
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() { if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) chatBody.AppendMessage(models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
} }
filteredMessages, botPersona := filterMessagesForCurrentCharacter(chatBody.Messages) filteredMessages, botPersona := filterMessagesForCurrentCharacter(chatBody.GetMessages())
messages := make([]string, len(filteredMessages)) messages := make([]string, len(filteredMessages))
for i := range filteredMessages { for i := range filteredMessages {
messages[i] = stripThinkingFromMsg(&filteredMessages[i]).ToPrompt() messages[i] = stripThinkingFromMsg(&filteredMessages[i]).ToPrompt()
@@ -547,7 +548,7 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader
stopSlice := chatBody.MakeStopSliceExcluding("", listChatRoles()) stopSlice := chatBody.MakeStopSliceExcluding("", listChatRoles())
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt, "stop_strings", stopSlice) "msg", msg, "resume", resume, "prompt", prompt, "stop_strings", stopSlice)
payload := models.NewOpenRouterCompletionReq(chatBody.Model, prompt, payload := models.NewOpenRouterCompletionReq(chatBody.GetModel(), prompt,
defaultLCPProps, stopSlice) defaultLCPProps, stopSlice)
data, err := json.Marshal(payload) data, err := json.Marshal(payload)
if err != nil { if err != nil {
@@ -633,15 +634,15 @@ func (or OpenRouterChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
newMsg = models.NewRoleMsg(role, msg) newMsg = models.NewRoleMsg(role, msg)
} }
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.AppendMessage(newMsg)
} }
// Create copy of chat body with standardized user role // Create copy of chat body with standardized user role
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages) filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.GetMessages())
// Add persona suffix to the last user message to indicate who the assistant should reply as // Add persona suffix to the last user message to indicate who the assistant should reply as
bodyCopy := &models.ChatBody{ bodyCopy := &models.ChatBody{
Messages: make([]models.RoleMsg, len(filteredMessages)), Messages: make([]models.RoleMsg, len(filteredMessages)),
Model: chatBody.Model, Model: chatBody.GetModel(),
Stream: chatBody.Stream, Stream: chatBody.GetStream(),
} }
for i := range filteredMessages { for i := range filteredMessages {
strippedMsg := *stripThinkingFromMsg(&filteredMessages[i]) strippedMsg := *stripThinkingFromMsg(&filteredMessages[i])

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
"sync"
) )
type FuncCall struct { type FuncCall struct {
@@ -639,3 +640,253 @@ type MultimodalToolResp struct {
Type string `json:"type"` Type string `json:"type"`
Parts []map[string]string `json:"parts"` Parts []map[string]string `json:"parts"`
} }
// SafeChatBody is a thread-safe wrapper around ChatBody using RWMutex.
// This allows safe concurrent access to chat state from multiple goroutines.
type SafeChatBody struct {
mu sync.RWMutex
ChatBody
}
// NewSafeChatBody creates a new SafeChatBody from an existing ChatBody.
// If cb is nil, creates an empty ChatBody.
func NewSafeChatBody(cb *ChatBody) *SafeChatBody {
if cb == nil {
return &SafeChatBody{
ChatBody: ChatBody{
Messages: []RoleMsg{},
},
}
}
return &SafeChatBody{
ChatBody: *cb,
}
}
// GetModel returns the model name (thread-safe read).
func (s *SafeChatBody) GetModel() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.Model
}
// SetModel sets the model name (thread-safe write).
func (s *SafeChatBody) SetModel(model string) {
s.mu.Lock()
defer s.mu.Unlock()
s.Model = model
}
// GetStream returns the stream flag (thread-safe read).
func (s *SafeChatBody) GetStream() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.Stream
}
// SetStream sets the stream flag (thread-safe write).
func (s *SafeChatBody) SetStream(stream bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.Stream = stream
}
// GetMessages returns a copy of all messages (thread-safe read).
// Returns a copy to prevent race conditions after the lock is released.
func (s *SafeChatBody) GetMessages() []RoleMsg {
s.mu.RLock()
defer s.mu.RUnlock()
// Return a copy to prevent external modification
messagesCopy := make([]RoleMsg, len(s.Messages))
copy(messagesCopy, s.Messages)
return messagesCopy
}
// SetMessages replaces all messages (thread-safe write).
func (s *SafeChatBody) SetMessages(messages []RoleMsg) {
s.mu.Lock()
defer s.mu.Unlock()
s.Messages = messages
}
// AppendMessage adds a message to the end (thread-safe write).
func (s *SafeChatBody) AppendMessage(msg RoleMsg) {
s.mu.Lock()
defer s.mu.Unlock()
s.Messages = append(s.Messages, msg)
}
// GetMessageAt returns a message at a specific index (thread-safe read).
// Returns the message and a boolean indicating if the index was valid.
func (s *SafeChatBody) GetMessageAt(index int) (RoleMsg, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
if index < 0 || index >= len(s.Messages) {
return RoleMsg{}, false
}
return s.Messages[index], true
}
// SetMessageAt updates a message at a specific index (thread-safe write).
// Returns false if index is out of bounds.
func (s *SafeChatBody) SetMessageAt(index int, msg RoleMsg) bool {
s.mu.Lock()
defer s.mu.Unlock()
if index < 0 || index >= len(s.Messages) {
return false
}
s.Messages[index] = msg
return true
}
// GetLastMessage returns the last message (thread-safe read).
// Returns the message and a boolean indicating if the chat has messages.
func (s *SafeChatBody) GetLastMessage() (RoleMsg, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
if len(s.Messages) == 0 {
return RoleMsg{}, false
}
return s.Messages[len(s.Messages)-1], true
}
// GetMessageCount returns the number of messages (thread-safe read).
func (s *SafeChatBody) GetMessageCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.Messages)
}
// RemoveLastMessage removes the last message (thread-safe write).
// Returns false if there are no messages.
func (s *SafeChatBody) RemoveLastMessage() bool {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.Messages) == 0 {
return false
}
s.Messages = s.Messages[:len(s.Messages)-1]
return true
}
// TruncateMessages keeps only the first n messages (thread-safe write).
func (s *SafeChatBody) TruncateMessages(n int) {
s.mu.Lock()
defer s.mu.Unlock()
if n < len(s.Messages) {
s.Messages = s.Messages[:n]
}
}
// ClearMessages removes all messages (thread-safe write).
func (s *SafeChatBody) ClearMessages() {
s.mu.Lock()
defer s.mu.Unlock()
s.Messages = []RoleMsg{}
}
// Rename renames all occurrences of oldname to newname in messages (thread-safe read-modify-write).
func (s *SafeChatBody) Rename(oldname, newname string) {
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.Messages {
s.Messages[i].Content = strings.ReplaceAll(s.Messages[i].Content, oldname, newname)
s.Messages[i].Role = strings.ReplaceAll(s.Messages[i].Role, oldname, newname)
}
}
// ListRoles returns all unique roles in messages (thread-safe read).
func (s *SafeChatBody) ListRoles() []string {
s.mu.RLock()
defer s.mu.RUnlock()
namesMap := make(map[string]struct{})
for i := range s.Messages {
namesMap[s.Messages[i].Role] = struct{}{}
}
resp := make([]string, len(namesMap))
i := 0
for k := range namesMap {
resp[i] = k
i++
}
return resp
}
// MakeStopSlice returns stop strings for all roles (thread-safe read).
func (s *SafeChatBody) MakeStopSlice() []string {
return s.MakeStopSliceExcluding("", s.ListRoles())
}
// MakeStopSliceExcluding returns stop strings excluding a specific role (thread-safe read).
func (s *SafeChatBody) MakeStopSliceExcluding(excludeRole string, roleList []string) []string {
s.mu.RLock()
defer s.mu.RUnlock()
ss := []string{}
for _, role := range roleList {
if role == excludeRole {
continue
}
ss = append(ss,
role+":\n",
role+":",
role+": ",
role+": ",
role+": \n",
role+": ",
)
}
return ss
}
// UpdateMessageFunc updates a message at index using a provided function.
// The function receives the current message and returns the updated message.
// This is atomic and thread-safe (read-modify-write under single lock).
// Returns false if index is out of bounds.
func (s *SafeChatBody) UpdateMessageFunc(index int, updater func(RoleMsg) RoleMsg) bool {
s.mu.Lock()
defer s.mu.Unlock()
if index < 0 || index >= len(s.Messages) {
return false
}
s.Messages[index] = updater(s.Messages[index])
return true
}
// AppendMessageFunc appends a new message created by a provided function.
// The function receives the current message count and returns the new message.
// This is atomic and thread-safe.
func (s *SafeChatBody) AppendMessageFunc(creator func(count int) RoleMsg) {
s.mu.Lock()
defer s.mu.Unlock()
msg := creator(len(s.Messages))
s.Messages = append(s.Messages, msg)
}
// GetMessagesForLLM returns a filtered copy of messages for sending to LLM.
// This is thread-safe and returns a copy safe for external modification.
func (s *SafeChatBody) GetMessagesForLLM(filterFunc func([]RoleMsg) []RoleMsg) []RoleMsg {
s.mu.RLock()
defer s.mu.RUnlock()
if filterFunc == nil {
messagesCopy := make([]RoleMsg, len(s.Messages))
copy(messagesCopy, s.Messages)
return messagesCopy
}
return filterFunc(s.Messages)
}
// WithLock executes a function while holding the write lock.
// Use this for complex operations that need to be atomic.
func (s *SafeChatBody) WithLock(fn func(*ChatBody)) {
s.mu.Lock()
defer s.mu.Unlock()
fn(&s.ChatBody)
}
// WithRLock executes a function while holding the read lock.
// Use this for complex read-only operations.
func (s *SafeChatBody) WithRLock(fn func(*ChatBody)) {
s.mu.RLock()
defer s.mu.RUnlock()
fn(&s.ChatBody)
}

View File

@@ -22,7 +22,7 @@ func showModelSelectionPopup() {
models, err := fetchLCPModelsWithLoadStatus() models, err := fetchLCPModelsWithLoadStatus()
if err != nil { if err != nil {
logger.Error("failed to fetch models with load status", "error", err) logger.Error("failed to fetch models with load status", "error", err)
return LocalModels return LocalModels.Load().([]string)
} }
return models return models
} }
@@ -30,7 +30,8 @@ func showModelSelectionPopup() {
modelList := getModelListForAPI(cfg.CurrentAPI) modelList := getModelListForAPI(cfg.CurrentAPI)
// Check for empty options list // Check for empty options list
if len(modelList) == 0 { if len(modelList) == 0 {
logger.Warn("empty model list for", "api", cfg.CurrentAPI, "localModelsLen", len(LocalModels), "orModelsLen", len(ORFreeModels)) localModels := LocalModels.Load().([]string)
logger.Warn("empty model list for", "api", cfg.CurrentAPI, "localModelsLen", len(localModels), "orModelsLen", len(ORFreeModels))
var message string var message string
switch { switch {
case strings.Contains(cfg.CurrentAPI, "openrouter.ai"): case strings.Contains(cfg.CurrentAPI, "openrouter.ai"):
@@ -50,7 +51,7 @@ func showModelSelectionPopup() {
// Find the current model index to set as selected // Find the current model index to set as selected
currentModelIndex := -1 currentModelIndex := -1
for i, model := range modelList { for i, model := range modelList {
if strings.TrimPrefix(model, models.LoadedMark) == chatBody.Model { if strings.TrimPrefix(model, models.LoadedMark) == chatBody.GetModel() {
currentModelIndex = i currentModelIndex = i
} }
modelListWidget.AddItem(model, "", 0, nil) modelListWidget.AddItem(model, "", 0, nil)
@@ -61,8 +62,8 @@ func showModelSelectionPopup() {
} }
modelListWidget.SetSelectedFunc(func(index int, mainText string, secondaryText string, shortcut rune) { modelListWidget.SetSelectedFunc(func(index int, mainText string, secondaryText string, shortcut rune) {
modelName := strings.TrimPrefix(mainText, models.LoadedMark) modelName := strings.TrimPrefix(mainText, models.LoadedMark)
chatBody.Model = modelName chatBody.SetModel(modelName)
cfg.CurrentModel = chatBody.Model cfg.CurrentModel = chatBody.GetModel()
pages.RemovePage("modelSelectionPopup") pages.RemovePage("modelSelectionPopup")
app.SetFocus(textArea) app.SetFocus(textArea)
updateCachedModelColor() updateCachedModelColor()
@@ -150,15 +151,13 @@ func showAPILinkSelectionPopup() {
} }
// Assume local llama.cpp // Assume local llama.cpp
refreshLocalModelsIfEmpty() refreshLocalModelsIfEmpty()
localModelsMu.RLock() return LocalModels.Load().([]string)
defer localModelsMu.RUnlock()
return LocalModels
} }
newModelList := getModelListForAPI(cfg.CurrentAPI) newModelList := getModelListForAPI(cfg.CurrentAPI)
// Ensure chatBody.Model is in the new list; if not, set to first available model // Ensure chatBody.Model is in the new list; if not, set to first available model
if len(newModelList) > 0 && !slices.Contains(newModelList, chatBody.Model) { if len(newModelList) > 0 && !slices.Contains(newModelList, chatBody.GetModel()) {
chatBody.Model = strings.TrimPrefix(newModelList[0], models.LoadedMark) chatBody.SetModel(strings.TrimPrefix(newModelList[0], models.LoadedMark))
cfg.CurrentModel = chatBody.Model cfg.CurrentModel = chatBody.GetModel()
updateToolCapabilities() updateToolCapabilities()
} }
pages.RemovePage("apiLinkSelectionPopup") pages.RemovePage("apiLinkSelectionPopup")
@@ -229,7 +228,7 @@ func showUserRoleSelectionPopup() {
// Update the user role in config // Update the user role in config
cfg.WriteNextMsgAs = mainText cfg.WriteNextMsgAs = mainText
// role got switch, update textview with character specific context for user // role got switch, update textview with character specific context for user
filtered := filterMessagesForCharacter(chatBody.Messages, mainText) filtered := filterMessagesForCharacter(chatBody.GetMessages(), mainText)
textView.SetText(chatToText(filtered, cfg.ShowSys)) textView.SetText(chatToText(filtered, cfg.ShowSys))
// Remove the popup page // Remove the popup page
pages.RemovePage("userRoleSelectionPopup") pages.RemovePage("userRoleSelectionPopup")

View File

@@ -4,14 +4,11 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/gdamore/tcell/v2" "github.com/gdamore/tcell/v2"
"github.com/rivo/tview" "github.com/rivo/tview"
) )
var _ = sync.RWMutex{}
// Define constants for cell types // Define constants for cell types
const ( const (
CellTypeCheckbox = "checkbox" CellTypeCheckbox = "checkbox"
@@ -157,9 +154,7 @@ func makePropsTable(props map[string]float32) *tview.Table {
} }
// Assume local llama.cpp // Assume local llama.cpp
refreshLocalModelsIfEmpty() refreshLocalModelsIfEmpty()
localModelsMu.RLock() return LocalModels.Load().([]string)
defer localModelsMu.RUnlock()
return LocalModels
} }
// Add input fields // Add input fields
addInputRow("New char to write msg as", "", func(text string) { addInputRow("New char to write msg as", "", func(text string) {
@@ -262,7 +257,8 @@ func makePropsTable(props map[string]float32) *tview.Table {
// Check for empty options list // Check for empty options list
if len(data.Options) == 0 { if len(data.Options) == 0 {
logger.Warn("empty options list for", "label", label, "api", cfg.CurrentAPI, "localModelsLen", len(LocalModels), "orModelsLen", len(ORFreeModels)) localModels := LocalModels.Load().([]string)
logger.Warn("empty options list for", "label", label, "api", cfg.CurrentAPI, "localModelsLen", len(localModels), "orModelsLen", len(ORFreeModels))
message := "No options available for " + label message := "No options available for " + label
if label == "Select a model" { if label == "Select a model" {
switch { switch {

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"gf-lt/config" "gf-lt/config"
"gf-lt/models" "gf-lt/models"
"gf-lt/onnx"
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
@@ -157,6 +156,43 @@ type ONNXEmbedder struct {
modelPath string modelPath string
} }
var onnxInitOnce sync.Once
var onnxReady bool
var onnxLibPath string
var cudaLibPath string
var onnxLibPaths = []string{
"/usr/lib/libonnxruntime.so",
"/usr/lib/libonnxruntime.so.1.24.2",
"/usr/local/lib/libonnxruntime.so",
"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
"/opt/onnxruntime/lib/libonnxruntime.so",
}
var cudaLibPaths = []string{
"/usr/lib/libonnxruntime_providers_cuda.so",
"/usr/local/lib/libonnxruntime_providers_cuda.so",
"/opt/onnxruntime/lib/libonnxruntime_providers_cuda.so",
}
func findONNXLibrary() string {
for _, path := range onnxLibPaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return ""
}
func findCUDALibrary() string {
for _, path := range cudaLibPaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return ""
}
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) { func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
// Check if model and tokenizer files exist // Check if model and tokenizer files exist
if _, err := os.Stat(modelPath); err != nil { if _, err := os.Stat(modelPath); err != nil {
@@ -166,16 +202,17 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log
return nil, fmt.Errorf("tokenizer not found: %w", err) return nil, fmt.Errorf("tokenizer not found: %w", err)
} }
// Initialize ONNX runtime // Find ONNX library
if err := onnx.Init(); err != nil { onnxLibPath = findONNXLibrary()
return nil, fmt.Errorf("ONNX init failed: %w", err) if onnxLibPath == "" {
} return nil, errors.New("ONNX runtime library not found in standard locations")
if onnx.HasCUDASupport() {
logger.Info("ONNX CUDA support enabled")
} else {
logger.Info("ONNX using CPU fallback")
} }
// Find CUDA provider library (optional)
cudaLibPath = findCUDALibrary()
if cudaLibPath == "" {
fmt.Println("WARNING: CUDA provider library not found, will use CPU")
}
emb := &ONNXEmbedder{ emb := &ONNXEmbedder{
tokenizerPath: tokenizerPath, tokenizerPath: tokenizerPath,
dims: dims, dims: dims,
@@ -202,12 +239,26 @@ func (e *ONNXEmbedder) ensureInitialized() error {
} }
e.tokenizer = tok e.tokenizer = tok
} }
// ONNX runtime already initialized by onnx.Init() in NewONNXEmbedder onnxInitOnce.Do(func() {
if !onnx.IsReady() { onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
if err := onnxruntime_go.InitializeEnvironment(); err != nil {
e.logger.Error("failed to initialize ONNX runtime", "error", err)
onnxReady = false
return
}
// Register CUDA provider if available
if cudaLibPath != "" {
if err := onnxruntime_go.RegisterExecutionProviderLibrary("CUDA", cudaLibPath); err != nil {
e.logger.Warn("failed to register CUDA provider", "error", err)
}
}
onnxReady = true
})
if !onnxReady {
return errors.New("ONNX runtime not ready") return errors.New("ONNX runtime not ready")
} }
// Create session options // Create session options
opts, err := onnx.NewSessionOptions() opts, err := onnxruntime_go.NewSessionOptions()
if err != nil { if err != nil {
return fmt.Errorf("failed to create session options: %w", err) return fmt.Errorf("failed to create session options: %w", err)
} }
@@ -215,7 +266,27 @@ func (e *ONNXEmbedder) ensureInitialized() error {
_ = opts.Destroy() _ = opts.Destroy()
}() }()
if onnx.HasCUDASupport() { // Try to add CUDA provider
useCUDA := cudaLibPath != ""
if useCUDA {
cudaOpts, err := onnxruntime_go.NewCUDAProviderOptions()
if err != nil {
e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err)
useCUDA = false
} else {
defer func() {
_ = cudaOpts.Destroy()
}()
if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil {
e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err)
useCUDA = false
} else if err := opts.AppendExecutionProviderCUDA(cudaOpts); err != nil {
e.logger.Warn("failed to append CUDA provider, falling back to CPU", "error", err)
useCUDA = false
}
}
}
if useCUDA {
e.logger.Info("Using CUDA for ONNX inference") e.logger.Info("Using CUDA for ONNX inference")
} else { } else {
e.logger.Info("Using CPU for ONNX inference") e.logger.Info("Using CPU for ONNX inference")

View File

@@ -29,7 +29,7 @@ func historyToSJSON(msgs []models.RoleMsg) (string, error) {
} }
func exportChat() error { func exportChat() error {
data, err := json.MarshalIndent(chatBody.Messages, "", " ") data, err := json.MarshalIndent(chatBody.GetMessages(), "", " ")
if err != nil { if err != nil {
return err return err
} }
@@ -54,7 +54,7 @@ func importChat(filename string) error {
if _, ok := chatMap[activeChatName]; !ok { if _, ok := chatMap[activeChatName]; !ok {
addNewChat(activeChatName) addNewChat(activeChatName)
} }
chatBody.Messages = messages chatBody.SetMessages(messages)
cfg.AssistantRole = messages[1].Role cfg.AssistantRole = messages[1].Role
if cfg.AssistantRole == cfg.UserRole { if cfg.AssistantRole == cfg.UserRole {
cfg.AssistantRole = messages[2].Role cfg.AssistantRole = messages[2].Role

View File

@@ -128,8 +128,8 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
pages.RemovePage(historyPage) pages.RemovePage(historyPage)
return return
} }
chatBody.Messages = history chatBody.SetMessages(history)
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
activeChatName = selectedChat activeChatName = selectedChat
pages.RemovePage(historyPage) pages.RemovePage(historyPage)
return return
@@ -149,8 +149,8 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
} }
showToast("chat deleted", selectedChat+" was deleted") showToast("chat deleted", selectedChat+" was deleted")
// load last chat // load last chat
chatBody.Messages = loadOldChatOrGetNew() chatBody.SetMessages(loadOldChatOrGetNew())
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
pages.RemovePage(historyPage) pages.RemovePage(historyPage)
return return
case "update card": case "update card":
@@ -163,16 +163,24 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
showToast("error", "no such card: "+agentName) showToast("error", "no such card: "+agentName)
return return
} }
cc.SysPrompt = chatBody.Messages[0].Content if msg0, ok := chatBody.GetMessageAt(0); ok {
cc.FirstMsg = chatBody.Messages[1].Content cc.SysPrompt = msg0.Content
}
if msg1, ok := chatBody.GetMessageAt(1); ok {
cc.FirstMsg = msg1.Content
}
if err := pngmeta.WriteToPng(cc.ToSpec(cfg.UserRole), cc.FilePath, cc.FilePath); err != nil { if err := pngmeta.WriteToPng(cc.ToSpec(cfg.UserRole), cc.FilePath, cc.FilePath); err != nil {
logger.Error("failed to write charcard", "error", err) logger.Error("failed to write charcard", "error", err)
} }
return return
case "move sysprompt onto 1st msg": case "move sysprompt onto 1st msg":
chatBody.Messages[1].Content = chatBody.Messages[0].Content + chatBody.Messages[1].Content chatBody.WithLock(func(cb *models.ChatBody) {
chatBody.Messages[0].Content = rpDefenitionSysMsg if len(cb.Messages) >= 2 {
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) cb.Messages[1].Content = cb.Messages[0].Content + cb.Messages[1].Content
cb.Messages[0].Content = rpDefenitionSysMsg
}
})
textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
activeChatName = selectedChat activeChatName = selectedChat
pages.RemovePage(historyPage) pages.RemovePage(historyPage)
return return
@@ -563,7 +571,7 @@ func makeAgentTable(agentList []string) *tview.Table {
return return
} }
// replace textview // replace textview
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
colorText() colorText()
updateStatusLine() updateStatusLine()
// sysModal.ClearButtons() // sysModal.ClearButtons()
@@ -732,7 +740,7 @@ func makeImportChatTable(filenames []string) *tview.Table {
colorText() colorText()
updateStatusLine() updateStatusLine()
// redraw the text in text area // redraw the text in text area
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
pages.RemovePage(historyPage) pages.RemovePage(historyPage)
app.SetFocus(textArea) app.SetFocus(textArea)
return return

View File

@@ -1215,11 +1215,11 @@ func isCommandAllowed(command string, args ...string) bool {
} }
func summarizeChat(args map[string]string) []byte { func summarizeChat(args map[string]string) []byte {
if len(chatBody.Messages) == 0 { if chatBody.GetMessageCount() == 0 {
return []byte("No chat history to summarize.") return []byte("No chat history to summarize.")
} }
// Format chat history for the agent // Format chat history for the agent
chatText := chatToText(chatBody.Messages, true) // include system and tool messages chatText := chatToText(chatBody.GetMessages(), true) // include system and tool messages
return []byte(chatText) return []byte(chatText)
} }

56
tui.go
View File

@@ -355,7 +355,7 @@ func init() {
searchResults = nil // Clear search results searchResults = nil // Clear search results
searchResultLengths = nil // Clear search result lengths searchResultLengths = nil // Clear search result lengths
originalTextForSearch = "" originalTextForSearch = ""
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) // Reset text without search regions textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys)) // Reset text without search regions
colorText() // Apply normal chat coloring colorText() // Apply normal chat coloring
} else { } else {
// Original logic if no search is active // Original logic if no search is active
@@ -436,9 +436,11 @@ func init() {
pages.RemovePage(editMsgPage) pages.RemovePage(editMsgPage)
return nil return nil
} }
chatBody.Messages[selectedIndex].SetText(editedMsg) chatBody.WithLock(func(cb *models.ChatBody) {
cb.Messages[selectedIndex].SetText(editedMsg)
})
// change textarea // change textarea
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
pages.RemovePage(editMsgPage) pages.RemovePage(editMsgPage)
editMode = false editMode = false
return nil return nil
@@ -466,9 +468,11 @@ func init() {
pages.RemovePage(roleEditPage) pages.RemovePage(roleEditPage)
return return
} }
if selectedIndex >= 0 && selectedIndex < len(chatBody.Messages) { if selectedIndex >= 0 && selectedIndex < chatBody.GetMessageCount() {
chatBody.Messages[selectedIndex].Role = newRole chatBody.WithLock(func(cb *models.ChatBody) {
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) cb.Messages[selectedIndex].Role = newRole
})
textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
colorText() colorText()
pages.RemovePage(roleEditPage) pages.RemovePage(roleEditPage)
} }
@@ -497,7 +501,7 @@ func init() {
return nil return nil
} }
selectedIndex = siInt selectedIndex = siInt
if len(chatBody.Messages)-1 < selectedIndex || selectedIndex < 0 { if chatBody.GetMessageCount()-1 < selectedIndex || selectedIndex < 0 {
msg := "chosen index is out of bounds, will copy user input" msg := "chosen index is out of bounds, will copy user input"
logger.Warn(msg, "index", selectedIndex) logger.Warn(msg, "index", selectedIndex)
showToast("error", msg) showToast("error", msg)
@@ -507,7 +511,7 @@ func init() {
hideIndexBar() // Hide overlay instead of removing page directly hideIndexBar() // Hide overlay instead of removing page directly
return nil return nil
} }
m := chatBody.Messages[selectedIndex] m := chatBody.GetMessages()[selectedIndex]
switch { switch {
case roleEditMode: case roleEditMode:
hideIndexBar() // Hide overlay first hideIndexBar() // Hide overlay first
@@ -574,7 +578,7 @@ func init() {
searchResults = nil searchResults = nil
searchResultLengths = nil searchResultLengths = nil
originalTextForSearch = "" originalTextForSearch = ""
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
colorText() colorText()
return return
} else { } else {
@@ -632,7 +636,7 @@ func init() {
// //
textArea.SetMovedFunc(updateStatusLine) textArea.SetMovedFunc(updateStatusLine)
updateStatusLine() updateStatusLine()
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
colorText() colorText()
if scrollToEndEnabled { if scrollToEndEnabled {
textView.ScrollToEnd() textView.ScrollToEnd()
@@ -646,7 +650,7 @@ func init() {
if event.Key() == tcell.KeyRune && event.Rune() == '5' && event.Modifiers()&tcell.ModAlt != 0 { if event.Key() == tcell.KeyRune && event.Rune() == '5' && event.Modifiers()&tcell.ModAlt != 0 {
// switch cfg.ShowSys // switch cfg.ShowSys
cfg.ShowSys = !cfg.ShowSys cfg.ShowSys = !cfg.ShowSys
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
colorText() colorText()
} }
if event.Key() == tcell.KeyRune && event.Rune() == '3' && event.Modifiers()&tcell.ModAlt != 0 { if event.Key() == tcell.KeyRune && event.Rune() == '3' && event.Modifiers()&tcell.ModAlt != 0 {
@@ -679,7 +683,7 @@ func init() {
// Handle Alt+T to toggle thinking block visibility // Handle Alt+T to toggle thinking block visibility
if event.Key() == tcell.KeyRune && event.Rune() == 't' && event.Modifiers()&tcell.ModAlt != 0 { if event.Key() == tcell.KeyRune && event.Rune() == 't' && event.Modifiers()&tcell.ModAlt != 0 {
thinkingCollapsed = !thinkingCollapsed thinkingCollapsed = !thinkingCollapsed
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
colorText() colorText()
status := "expanded" status := "expanded"
if thinkingCollapsed { if thinkingCollapsed {
@@ -691,7 +695,7 @@ func init() {
// Handle Ctrl+T to toggle tool call/response visibility // Handle Ctrl+T to toggle tool call/response visibility
if event.Key() == tcell.KeyCtrlT { if event.Key() == tcell.KeyCtrlT {
toolCollapsed = !toolCollapsed toolCollapsed = !toolCollapsed
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
colorText() colorText()
status := "expanded" status := "expanded"
if toolCollapsed { if toolCollapsed {
@@ -734,14 +738,14 @@ func init() {
} }
if event.Key() == tcell.KeyF2 && !botRespMode { if event.Key() == tcell.KeyF2 && !botRespMode {
// regen last msg // regen last msg
if len(chatBody.Messages) == 0 { if chatBody.GetMessageCount() == 0 {
showToast("info", "no messages to regenerate") showToast("info", "no messages to regenerate")
return nil return nil
} }
chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1] chatBody.TruncateMessages(chatBody.GetMessageCount() - 1)
// there is no case where user msg is regenerated // there is no case where user msg is regenerated
// lastRole := chatBody.Messages[len(chatBody.Messages)-1].Role // lastRole := chatBody.GetMessages()[chatBody.GetMessageCount()-1].Role
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
// go chatRound("", cfg.UserRole, textView, true, false) // go chatRound("", cfg.UserRole, textView, true, false)
if cfg.TTS_ENABLED { if cfg.TTS_ENABLED {
TTSDoneChan <- true TTSDoneChan <- true
@@ -760,12 +764,12 @@ func init() {
colorText() colorText()
return nil return nil
} }
if len(chatBody.Messages) == 0 { if chatBody.GetMessageCount() == 0 {
showToast("info", "no messages to delete") showToast("info", "no messages to delete")
return nil return nil
} }
chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1] chatBody.TruncateMessages(chatBody.GetMessageCount() - 1)
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys)) textView.SetText(chatToText(chatBody.GetMessages(), cfg.ShowSys))
if cfg.TTS_ENABLED { if cfg.TTS_ENABLED {
TTSDoneChan <- true TTSDoneChan <- true
} }
@@ -813,7 +817,7 @@ func init() {
if event.Key() == tcell.KeyF7 { if event.Key() == tcell.KeyF7 {
// copy msg to clipboard // copy msg to clipboard
editMode = false editMode = false
m := chatBody.Messages[len(chatBody.Messages)-1] m := chatBody.GetMessages()[chatBody.GetMessageCount()-1]
msgText := m.GetText() msgText := m.GetText()
if err := copyToClipboard(msgText); err != nil { if err := copyToClipboard(msgText); err != nil {
logger.Error("failed to copy to clipboard", "error", err) logger.Error("failed to copy to clipboard", "error", err)
@@ -997,10 +1001,10 @@ func init() {
TTSDoneChan <- true TTSDoneChan <- true
} }
if event.Key() == tcell.KeyRune && event.Rune() == '0' && event.Modifiers()&tcell.ModAlt != 0 && cfg.TTS_ENABLED { if event.Key() == tcell.KeyRune && event.Rune() == '0' && event.Modifiers()&tcell.ModAlt != 0 && cfg.TTS_ENABLED {
if len(chatBody.Messages) > 0 { if chatBody.GetMessageCount() > 0 {
// Stop any currently playing TTS first // Stop any currently playing TTS first
TTSDoneChan <- true TTSDoneChan <- true
lastMsg := chatBody.Messages[len(chatBody.Messages)-1] lastMsg := chatBody.GetMessages()[chatBody.GetMessageCount()-1]
cleanedText := models.CleanText(lastMsg.GetText()) cleanedText := models.CleanText(lastMsg.GetText())
if cleanedText != "" { if cleanedText != "" {
// nolint: errcheck // nolint: errcheck
@@ -1012,7 +1016,7 @@ func init() {
if event.Key() == tcell.KeyCtrlW { if event.Key() == tcell.KeyCtrlW {
// INFO: continue bot/text message // INFO: continue bot/text message
// without new role // without new role
lastRole := chatBody.Messages[len(chatBody.Messages)-1].Role lastRole := chatBody.GetMessages()[chatBody.GetMessageCount()-1].Role
// go chatRound("", lastRole, textView, false, true) // go chatRound("", lastRole, textView, false, true)
chatRoundChan <- &models.ChatRoundReq{Role: lastRole, Resume: true} chatRoundChan <- &models.ChatRoundReq{Role: lastRole, Resume: true}
return nil return nil
@@ -1098,7 +1102,7 @@ func init() {
if event.Key() == tcell.KeyRune && event.Modifiers() == tcell.ModAlt && event.Rune() == '9' { if event.Key() == tcell.KeyRune && event.Modifiers() == tcell.ModAlt && event.Rune() == '9' {
// Warm up (load) the currently selected model // Warm up (load) the currently selected model
go warmUpModel() go warmUpModel()
showToast("model warmup", "loading model: "+chatBody.Model) showToast("model warmup", "loading model: "+chatBody.GetModel())
return nil return nil
} }
// cannot send msg in editMode or botRespMode // cannot send msg in editMode or botRespMode
@@ -1137,7 +1141,7 @@ func init() {
} }
// add user icon before user msg // add user icon before user msg
fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n", fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
nl, len(chatBody.Messages), persona, msgText) nl, chatBody.GetMessageCount(), persona, msgText)
textArea.SetText("", true) textArea.SetText("", true)
if scrollToEndEnabled { if scrollToEndEnabled {
textView.ScrollToEnd() textView.ScrollToEnd()