Compare commits
25 Commits
feat/resp-
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d51c5d0f3 | ||
|
|
b97cd67d72 | ||
|
|
888c9fec65 | ||
|
|
4f07994bdc | ||
|
|
776fd7a2c4 | ||
|
|
9c6b0dc1fa | ||
|
|
9f51bd3853 | ||
|
|
b386c1181f | ||
|
|
b8e7649e69 | ||
|
|
6664c1a0fc | ||
|
|
e0c3fe554f | ||
|
|
40943ff4d3 | ||
|
|
6c03a1a277 | ||
|
|
27288e2aaa | ||
|
|
1c728ec7a7 | ||
|
|
78059083c2 | ||
|
|
34cd4ac141 | ||
|
|
343366b12d | ||
|
|
978369eeaa | ||
|
|
c39e1c267d | ||
|
|
9af21895c6 | ||
|
|
e3bd6f219f | ||
|
|
ae62c2c8d8 | ||
|
|
04db7c2f01 | ||
|
|
3d889e70b5 |
13
Makefile
13
Makefile
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: setconfig run lint setup-whisper build-whisper download-whisper-model docker-up docker-down docker-logs noextra-run
|
.PHONY: setconfig run lint install-linters setup-whisper build-whisper download-whisper-model docker-up docker-down docker-logs noextra-run installdelve checkdelve
|
||||||
|
|
||||||
run: setconfig
|
run: setconfig
|
||||||
go build -tags extra -o gf-lt && ./gf-lt
|
go build -tags extra -o gf-lt && ./gf-lt
|
||||||
@@ -15,8 +15,17 @@ noextra-run: setconfig
|
|||||||
setconfig:
|
setconfig:
|
||||||
find config.toml &>/dev/null || cp config.example.toml config.toml
|
find config.toml &>/dev/null || cp config.example.toml config.toml
|
||||||
|
|
||||||
|
installdelve:
|
||||||
|
go install github.com/go-delve/delve/cmd/dlv@latest
|
||||||
|
|
||||||
|
checkdelve:
|
||||||
|
which dlv &>/dev/null || installdelve
|
||||||
|
|
||||||
|
install-linters: ## Install additional linters (noblanks)
|
||||||
|
go install github.com/GrailFinder/noblanks-linter/cmd/noblanks@latest
|
||||||
|
|
||||||
lint: ## Run linters. Use make install-linters first.
|
lint: ## Run linters. Use make install-linters first.
|
||||||
golangci-lint run -c .golangci.yml ./...
|
golangci-lint run -c .golangci.yml ./...; noblanks ./...
|
||||||
|
|
||||||
# Whisper STT Setup (in batteries directory)
|
# Whisper STT Setup (in batteries directory)
|
||||||
setup-whisper: build-whisper download-whisper-model
|
setup-whisper: build-whisper download-whisper-model
|
||||||
|
|||||||
@@ -140,7 +140,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
|||||||
ag.log.Error("failed to read request body", "error", err)
|
ag.log.Error("failed to read request body", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
|
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("failed to create request", "error", err)
|
ag.log.Error("failed to create request", "error", err)
|
||||||
@@ -150,22 +149,18 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
|||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Add("Authorization", "Bearer "+ag.getToken())
|
req.Header.Add("Authorization", "Bearer "+ag.getToken())
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
|
||||||
ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
|
ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
resp, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
|
ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
responseBytes, err := io.ReadAll(resp.Body)
|
responseBytes, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("failed to read response", "error", err)
|
ag.log.Error("failed to read response", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
|
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
|
||||||
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
|
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
|
||||||
@@ -178,7 +173,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
|||||||
// Return raw response as fallback
|
// Return raw response as fallback
|
||||||
return responseBytes, nil
|
return responseBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return []byte(text), nil
|
return []byte(text), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
143
bot.go
143
bot.go
@@ -23,8 +23,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/neurosnap/sentences/english"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -119,7 +117,7 @@ func processMessageTag(msg *models.RoleMsg) *models.RoleMsg {
|
|||||||
}
|
}
|
||||||
// If KnownTo already set, assume tag already processed (content cleaned).
|
// If KnownTo already set, assume tag already processed (content cleaned).
|
||||||
// However, we still check for new tags (maybe added later).
|
// However, we still check for new tags (maybe added later).
|
||||||
knownTo := parseKnownToTag(msg.Content)
|
knownTo := parseKnownToTag(msg.GetText())
|
||||||
// If tag found, replace KnownTo with new list (merge with existing?)
|
// If tag found, replace KnownTo with new list (merge with existing?)
|
||||||
// For simplicity, if knownTo is not nil, replace.
|
// For simplicity, if knownTo is not nil, replace.
|
||||||
if knownTo == nil {
|
if knownTo == nil {
|
||||||
@@ -411,14 +409,21 @@ func fetchLCPModelsWithLoadStatus() ([]string, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result := make([]string, 0, len(models.Data))
|
result := make([]string, 0, len(models.Data))
|
||||||
for _, m := range models.Data {
|
li := 0 // loaded index
|
||||||
|
for i, m := range models.Data {
|
||||||
modelName := m.ID
|
modelName := m.ID
|
||||||
if m.Status.Value == "loaded" {
|
if m.Status.Value == "loaded" {
|
||||||
modelName = "(loaded) " + modelName
|
modelName = "(loaded) " + modelName
|
||||||
|
li = i
|
||||||
}
|
}
|
||||||
result = append(result, modelName)
|
result = append(result, modelName)
|
||||||
}
|
}
|
||||||
return result, nil
|
if li == 0 {
|
||||||
|
return result, nil // no loaded models
|
||||||
|
}
|
||||||
|
loadedModel := result[li]
|
||||||
|
result = append(result[:li], result[li+1:]...)
|
||||||
|
return slices.Concat([]string{loadedModel}, result), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchLCPModelsWithStatus returns the full LCPModels struct including status information.
|
// fetchLCPModelsWithStatus returns the full LCPModels struct including status information.
|
||||||
@@ -569,7 +574,6 @@ func sendMsgToLLM(body io.Reader) {
|
|||||||
streamDone <- true
|
streamDone <- true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the initial response is an error before starting to stream
|
// Check if the initial response is an error before starting to stream
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
// Read the response body to get detailed error information
|
// Read the response body to get detailed error information
|
||||||
@@ -584,7 +588,6 @@ func sendMsgToLLM(body io.Reader) {
|
|||||||
streamDone <- true
|
streamDone <- true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the error response for detailed information
|
// Parse the error response for detailed information
|
||||||
detailedError := extractDetailedErrorFromBytes(bodyBytes, resp.StatusCode)
|
detailedError := extractDetailedErrorFromBytes(bodyBytes, resp.StatusCode)
|
||||||
logger.Error("API returned error status", "status_code", resp.StatusCode, "detailed_error", detailedError)
|
logger.Error("API returned error status", "status_code", resp.StatusCode, "detailed_error", detailedError)
|
||||||
@@ -710,7 +713,6 @@ func sendMsgToLLM(body io.Reader) {
|
|||||||
tokenCount++
|
tokenCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// When we get content and have been streaming reasoning, close the thinking block
|
// When we get content and have been streaming reasoning, close the thinking block
|
||||||
if chunk.Chunk != "" && hasReasoning && !reasoningSent {
|
if chunk.Chunk != "" && hasReasoning && !reasoningSent {
|
||||||
// Close the thinking block before sending actual content
|
// Close the thinking block before sending actual content
|
||||||
@@ -718,7 +720,6 @@ func sendMsgToLLM(body io.Reader) {
|
|||||||
tokenCount++
|
tokenCount++
|
||||||
reasoningSent = true
|
reasoningSent = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// bot sends way too many \n
|
// bot sends way too many \n
|
||||||
answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n")
|
answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n")
|
||||||
// Accumulate text to check for stop strings that might span across chunks
|
// Accumulate text to check for stop strings that might span across chunks
|
||||||
@@ -750,68 +751,6 @@ func sendMsgToLLM(body io.Reader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func chatRagUse(qText string) (string, error) {
|
|
||||||
logger.Debug("Starting RAG query", "original_query", qText)
|
|
||||||
tokenizer, err := english.NewSentenceTokenizer(nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to create sentence tokenizer", "error", err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// this where llm should find the questions in text and ask them
|
|
||||||
questionsS := tokenizer.Tokenize(qText)
|
|
||||||
questions := make([]string, len(questionsS))
|
|
||||||
for i, q := range questionsS {
|
|
||||||
questions[i] = q.Text
|
|
||||||
logger.Debug("RAG question extracted", "index", i, "question", q.Text)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(questions) == 0 {
|
|
||||||
logger.Warn("No questions extracted from query text", "query", qText)
|
|
||||||
return "No related results from RAG vector storage.", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
respVecs := []models.VectorRow{}
|
|
||||||
for i, q := range questions {
|
|
||||||
logger.Debug("Processing RAG question", "index", i, "question", q)
|
|
||||||
emb, err := ragger.LineToVector(q)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to get embeddings for RAG", "error", err, "index", i, "question", q)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.Debug("Got embeddings for question", "index", i, "question_len", len(q), "embedding_len", len(emb))
|
|
||||||
|
|
||||||
// Create EmbeddingResp struct for the search
|
|
||||||
embeddingResp := &models.EmbeddingResp{
|
|
||||||
Embedding: emb,
|
|
||||||
Index: 0, // Not used in search but required for the struct
|
|
||||||
}
|
|
||||||
vecs, err := ragger.SearchEmb(embeddingResp)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to query embeddings in RAG", "error", err, "index", i, "question", q)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.Debug("RAG search returned vectors", "index", i, "question", q, "vector_count", len(vecs))
|
|
||||||
respVecs = append(respVecs, vecs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// get raw text
|
|
||||||
resps := []string{}
|
|
||||||
logger.Debug("RAG query final results", "total_vecs_found", len(respVecs))
|
|
||||||
for _, rv := range respVecs {
|
|
||||||
resps = append(resps, rv.RawText)
|
|
||||||
logger.Debug("RAG result", "slug", rv.Slug, "filename", rv.FileName, "raw_text_len", len(rv.RawText))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resps) == 0 {
|
|
||||||
logger.Info("No RAG results found for query", "original_query", qText, "question_count", len(questions))
|
|
||||||
return "No related results from RAG vector storage.", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result := strings.Join(resps, "\n")
|
|
||||||
logger.Debug("RAG query completed", "result_len", len(result), "response_count", len(resps))
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func roleToIcon(role string) string {
|
func roleToIcon(role string) string {
|
||||||
return "<" + role + ">: "
|
return "<" + role + ">: "
|
||||||
}
|
}
|
||||||
@@ -829,14 +768,46 @@ func chatWatcher(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// inpired by https://github.com/rivo/tview/issues/225
|
||||||
|
func showSpinner() {
|
||||||
|
spinners := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
|
||||||
|
var i int
|
||||||
|
botPersona := cfg.AssistantRole
|
||||||
|
if cfg.WriteNextMsgAsCompletionAgent != "" {
|
||||||
|
botPersona = cfg.WriteNextMsgAsCompletionAgent
|
||||||
|
}
|
||||||
|
for botRespMode || toolRunningMode {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
spin := i % len(spinners)
|
||||||
|
app.QueueUpdateDraw(func() {
|
||||||
|
switch {
|
||||||
|
case toolRunningMode:
|
||||||
|
textArea.SetTitle(spinners[spin] + " tool")
|
||||||
|
case botRespMode:
|
||||||
|
textArea.SetTitle(spinners[spin] + " " + botPersona)
|
||||||
|
default:
|
||||||
|
textArea.SetTitle(spinners[spin] + " input")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
app.QueueUpdateDraw(func() {
|
||||||
|
textArea.SetTitle("input")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func chatRound(r *models.ChatRoundReq) error {
|
func chatRound(r *models.ChatRoundReq) error {
|
||||||
botRespMode = true
|
botRespMode = true
|
||||||
|
go showSpinner()
|
||||||
updateStatusLine()
|
updateStatusLine()
|
||||||
botPersona := cfg.AssistantRole
|
botPersona := cfg.AssistantRole
|
||||||
if cfg.WriteNextMsgAsCompletionAgent != "" {
|
if cfg.WriteNextMsgAsCompletionAgent != "" {
|
||||||
botPersona = cfg.WriteNextMsgAsCompletionAgent
|
botPersona = cfg.WriteNextMsgAsCompletionAgent
|
||||||
}
|
}
|
||||||
defer func() { botRespMode = false }()
|
defer func() {
|
||||||
|
botRespMode = false
|
||||||
|
ClearImageAttachment()
|
||||||
|
}()
|
||||||
// check that there is a model set to use if is not local
|
// check that there is a model set to use if is not local
|
||||||
choseChunkParser()
|
choseChunkParser()
|
||||||
reader, err := chunkParser.FormMsg(r.UserMsg, r.Role, r.Resume)
|
reader, err := chunkParser.FormMsg(r.UserMsg, r.Role, r.Resume)
|
||||||
@@ -855,13 +826,14 @@ func chatRound(r *models.ChatRoundReq) error {
|
|||||||
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{
|
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{
|
||||||
Role: botPersona, Content: "",
|
Role: botPersona, Content: "",
|
||||||
})
|
})
|
||||||
fmt.Fprintf(textView, "\n[-:-:b](%d) ", msgIdx)
|
nl := "\n\n"
|
||||||
fmt.Fprint(textView, roleToIcon(botPersona))
|
prevText := textView.GetText(true)
|
||||||
fmt.Fprint(textView, "[-:-:-]\n")
|
if strings.HasSuffix(prevText, nl) {
|
||||||
if cfg.ThinkUse && !strings.Contains(cfg.CurrentAPI, "v1") {
|
nl = ""
|
||||||
// fmt.Fprint(textView, "<think>")
|
} else if strings.HasSuffix(prevText, "\n") {
|
||||||
chunkChan <- "<think>"
|
nl = "\n"
|
||||||
}
|
}
|
||||||
|
fmt.Fprintf(textView, "%s[-:-:b](%d) %s[-:-:-]\n", nl, msgIdx, roleToIcon(botPersona))
|
||||||
} else {
|
} else {
|
||||||
msgIdx = len(chatBody.Messages) - 1
|
msgIdx = len(chatBody.Messages) - 1
|
||||||
}
|
}
|
||||||
@@ -1198,7 +1170,11 @@ func findCall(msg, toolCall string) bool {
|
|||||||
chatRoundChan <- crr
|
chatRoundChan <- crr
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
// Show tool call progress indicator before execution
|
||||||
|
fmt.Fprintf(textView, "\n[yellow::i][tool: %s...][-:-:-]", fc.Name)
|
||||||
|
toolRunningMode = true
|
||||||
resp := callToolWithAgent(fc.Name, fc.Args)
|
resp := callToolWithAgent(fc.Name, fc.Args)
|
||||||
|
toolRunningMode = false
|
||||||
toolMsg := string(resp) // Remove the "tool response: " prefix and %+v formatting
|
toolMsg := string(resp) // Remove the "tool response: " prefix and %+v formatting
|
||||||
logger.Info("llm used a tool call", "tool_name", fc.Name, "too_args", fc.Args, "id", fc.ID, "tool_resp", toolMsg)
|
logger.Info("llm used a tool call", "tool_name", fc.Name, "too_args", fc.Args, "id", fc.ID, "tool_resp", toolMsg)
|
||||||
fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
|
fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
|
||||||
@@ -1237,7 +1213,6 @@ func chatToTextSlice(messages []models.RoleMsg, showSys bool) []string {
|
|||||||
func chatToText(messages []models.RoleMsg, showSys bool) string {
|
func chatToText(messages []models.RoleMsg, showSys bool) string {
|
||||||
s := chatToTextSlice(messages, showSys)
|
s := chatToTextSlice(messages, showSys)
|
||||||
text := strings.Join(s, "\n")
|
text := strings.Join(s, "\n")
|
||||||
|
|
||||||
// Collapse thinking blocks if enabled
|
// Collapse thinking blocks if enabled
|
||||||
if thinkingCollapsed {
|
if thinkingCollapsed {
|
||||||
text = thinkRE.ReplaceAllStringFunc(text, func(match string) string {
|
text = thinkRE.ReplaceAllStringFunc(text, func(match string) string {
|
||||||
@@ -1261,7 +1236,6 @@ func chatToText(messages []models.RoleMsg, showSys bool) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return text
|
return text
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1272,12 +1246,9 @@ func removeThinking(chatBody *models.ChatBody) {
|
|||||||
if msg.Role == cfg.ToolRole {
|
if msg.Role == cfg.ToolRole {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// find thinking and remove it
|
// find thinking and remove it - use SetText to preserve ContentParts
|
||||||
rm := models.RoleMsg{
|
msg.SetText(thinkRE.ReplaceAllString(msg.GetText(), ""))
|
||||||
Role: msg.Role,
|
msgs = append(msgs, msg)
|
||||||
Content: thinkRE.ReplaceAllString(msg.Content, ""),
|
|
||||||
}
|
|
||||||
msgs = append(msgs, rm)
|
|
||||||
}
|
}
|
||||||
chatBody.Messages = msgs
|
chatBody.Messages = msgs
|
||||||
}
|
}
|
||||||
|
|||||||
34
bot_test.go
34
bot_test.go
@@ -1,12 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"gf-lt/config"
|
"gf-lt/config"
|
||||||
"gf-lt/models"
|
"gf-lt/models"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
||||||
// Mock config for testing
|
// Mock config for testing
|
||||||
testCfg := &config.Config{
|
testCfg := &config.Config{
|
||||||
@@ -14,7 +12,6 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
|||||||
WriteNextMsgAsCompletionAgent: "",
|
WriteNextMsgAsCompletionAgent: "",
|
||||||
}
|
}
|
||||||
cfg = testCfg
|
cfg = testCfg
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input []models.RoleMsg
|
input []models.RoleMsg
|
||||||
@@ -114,38 +111,31 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := consolidateAssistantMessages(tt.input)
|
result := consolidateAssistantMessages(tt.input)
|
||||||
|
|
||||||
if len(result) != len(tt.expected) {
|
if len(result) != len(tt.expected) {
|
||||||
t.Errorf("Expected %d messages, got %d", len(tt.expected), len(result))
|
t.Errorf("Expected %d messages, got %d", len(tt.expected), len(result))
|
||||||
t.Logf("Result: %+v", result)
|
t.Logf("Result: %+v", result)
|
||||||
t.Logf("Expected: %+v", tt.expected)
|
t.Logf("Expected: %+v", tt.expected)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, expectedMsg := range tt.expected {
|
for i, expectedMsg := range tt.expected {
|
||||||
if i >= len(result) {
|
if i >= len(result) {
|
||||||
t.Errorf("Result has fewer messages than expected at index %d", i)
|
t.Errorf("Result has fewer messages than expected at index %d", i)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
actualMsg := result[i]
|
actualMsg := result[i]
|
||||||
if actualMsg.Role != expectedMsg.Role {
|
if actualMsg.Role != expectedMsg.Role {
|
||||||
t.Errorf("Message %d: expected role '%s', got '%s'", i, expectedMsg.Role, actualMsg.Role)
|
t.Errorf("Message %d: expected role '%s', got '%s'", i, expectedMsg.Role, actualMsg.Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if actualMsg.Content != expectedMsg.Content {
|
if actualMsg.Content != expectedMsg.Content {
|
||||||
t.Errorf("Message %d: expected content '%s', got '%s'", i, expectedMsg.Content, actualMsg.Content)
|
t.Errorf("Message %d: expected content '%s', got '%s'", i, expectedMsg.Content, actualMsg.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
if actualMsg.ToolCallID != expectedMsg.ToolCallID {
|
if actualMsg.ToolCallID != expectedMsg.ToolCallID {
|
||||||
t.Errorf("Message %d: expected ToolCallID '%s', got '%s'", i, expectedMsg.ToolCallID, actualMsg.ToolCallID)
|
t.Errorf("Message %d: expected ToolCallID '%s', got '%s'", i, expectedMsg.ToolCallID, actualMsg.ToolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Additional check: ensure no messages were lost
|
// Additional check: ensure no messages were lost
|
||||||
if !reflect.DeepEqual(result, tt.expected) {
|
if !reflect.DeepEqual(result, tt.expected) {
|
||||||
t.Errorf("Result does not match expected:\nResult: %+v\nExpected: %+v", result, tt.expected)
|
t.Errorf("Result does not match expected:\nResult: %+v\nExpected: %+v", result, tt.expected)
|
||||||
@@ -153,7 +143,6 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalFuncCall(t *testing.T) {
|
func TestUnmarshalFuncCall(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -213,7 +202,6 @@ func TestUnmarshalFuncCall(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := unmarshalFuncCall(tt.jsonStr)
|
got, err := unmarshalFuncCall(tt.jsonStr)
|
||||||
@@ -238,7 +226,6 @@ func TestUnmarshalFuncCall(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertJSONToMapStringString(t *testing.T) {
|
func TestConvertJSONToMapStringString(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -265,7 +252,6 @@ func TestConvertJSONToMapStringString(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := convertJSONToMapStringString(tt.jsonStr)
|
got, err := convertJSONToMapStringString(tt.jsonStr)
|
||||||
@@ -287,7 +273,6 @@ func TestConvertJSONToMapStringString(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseKnownToTag(t *testing.T) {
|
func TestParseKnownToTag(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -378,7 +363,6 @@ func TestParseKnownToTag(t *testing.T) {
|
|||||||
wantKnownTo: []string{"Alice", "Bob", "Carl"},
|
wantKnownTo: []string{"Alice", "Bob", "Carl"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// Set up config
|
// Set up config
|
||||||
@@ -402,7 +386,6 @@ func TestParseKnownToTag(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessMessageTag(t *testing.T) {
|
func TestProcessMessageTag(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -498,7 +481,6 @@ func TestProcessMessageTag(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
testCfg := &config.Config{
|
testCfg := &config.Config{
|
||||||
@@ -529,7 +511,6 @@ func TestProcessMessageTag(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFilterMessagesForCharacter(t *testing.T) {
|
func TestFilterMessagesForCharacter(t *testing.T) {
|
||||||
messages := []models.RoleMsg{
|
messages := []models.RoleMsg{
|
||||||
{Role: "system", Content: "System message", KnownTo: nil}, // visible to all
|
{Role: "system", Content: "System message", KnownTo: nil}, // visible to all
|
||||||
@@ -539,7 +520,6 @@ func TestFilterMessagesForCharacter(t *testing.T) {
|
|||||||
{Role: "Alice", Content: "Private to Carl", KnownTo: []string{"Alice", "Carl"}},
|
{Role: "Alice", Content: "Private to Carl", KnownTo: []string{"Alice", "Carl"}},
|
||||||
{Role: "Carl", Content: "Hi all", KnownTo: nil}, // visible to all
|
{Role: "Carl", Content: "Hi all", KnownTo: nil}, // visible to all
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
enabled bool
|
enabled bool
|
||||||
@@ -583,7 +563,6 @@ func TestFilterMessagesForCharacter(t *testing.T) {
|
|||||||
wantIndices: []int{0, 1, 5},
|
wantIndices: []int{0, 1, 5},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
testCfg := &config.Config{
|
testCfg := &config.Config{
|
||||||
@@ -591,15 +570,12 @@ func TestFilterMessagesForCharacter(t *testing.T) {
|
|||||||
CharSpecificContextTag: "@",
|
CharSpecificContextTag: "@",
|
||||||
}
|
}
|
||||||
cfg = testCfg
|
cfg = testCfg
|
||||||
|
|
||||||
got := filterMessagesForCharacter(messages, tt.character)
|
got := filterMessagesForCharacter(messages, tt.character)
|
||||||
|
|
||||||
if len(got) != len(tt.wantIndices) {
|
if len(got) != len(tt.wantIndices) {
|
||||||
t.Errorf("filterMessagesForCharacter() returned %d messages, want %d", len(got), len(tt.wantIndices))
|
t.Errorf("filterMessagesForCharacter() returned %d messages, want %d", len(got), len(tt.wantIndices))
|
||||||
t.Logf("got: %v", got)
|
t.Logf("got: %v", got)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, idx := range tt.wantIndices {
|
for i, idx := range tt.wantIndices {
|
||||||
if got[i].Content != messages[idx].Content {
|
if got[i].Content != messages[idx].Content {
|
||||||
t.Errorf("filterMessagesForCharacter() message %d content = %q, want %q", i, got[i].Content, messages[idx].Content)
|
t.Errorf("filterMessagesForCharacter() message %d content = %q, want %q", i, got[i].Content, messages[idx].Content)
|
||||||
@@ -608,7 +584,6 @@ func TestFilterMessagesForCharacter(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoleMsgCopyPreservesKnownTo(t *testing.T) {
|
func TestRoleMsgCopyPreservesKnownTo(t *testing.T) {
|
||||||
// Test that the Copy() method preserves the KnownTo field
|
// Test that the Copy() method preserves the KnownTo field
|
||||||
originalMsg := models.RoleMsg{
|
originalMsg := models.RoleMsg{
|
||||||
@@ -616,9 +591,7 @@ func TestRoleMsgCopyPreservesKnownTo(t *testing.T) {
|
|||||||
Content: "Test message",
|
Content: "Test message",
|
||||||
KnownTo: []string{"Bob", "Charlie"},
|
KnownTo: []string{"Bob", "Charlie"},
|
||||||
}
|
}
|
||||||
|
|
||||||
copiedMsg := originalMsg.Copy()
|
copiedMsg := originalMsg.Copy()
|
||||||
|
|
||||||
if copiedMsg.Role != originalMsg.Role {
|
if copiedMsg.Role != originalMsg.Role {
|
||||||
t.Errorf("Copy() failed to preserve Role: got %q, want %q", copiedMsg.Role, originalMsg.Role)
|
t.Errorf("Copy() failed to preserve Role: got %q, want %q", copiedMsg.Role, originalMsg.Role)
|
||||||
}
|
}
|
||||||
@@ -635,7 +608,6 @@ func TestRoleMsgCopyPreservesKnownTo(t *testing.T) {
|
|||||||
t.Errorf("Copy() failed to preserve hasContentParts flag")
|
t.Errorf("Copy() failed to preserve hasContentParts flag")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKnownToFieldPreservationScenario(t *testing.T) {
|
func TestKnownToFieldPreservationScenario(t *testing.T) {
|
||||||
// Test the specific scenario from the log where KnownTo field was getting lost
|
// Test the specific scenario from the log where KnownTo field was getting lost
|
||||||
originalMsg := models.RoleMsg{
|
originalMsg := models.RoleMsg{
|
||||||
@@ -643,28 +615,22 @@ func TestKnownToFieldPreservationScenario(t *testing.T) {
|
|||||||
Content: `Alice: "Okay, Bob. The word is... **'Ephemeral'**. (ooc: @Bob@)"`,
|
Content: `Alice: "Okay, Bob. The word is... **'Ephemeral'**. (ooc: @Bob@)"`,
|
||||||
KnownTo: []string{"Bob"}, // This was detected in the log
|
KnownTo: []string{"Bob"}, // This was detected in the log
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("Original message - Role: %s, Content: %s, KnownTo: %v",
|
t.Logf("Original message - Role: %s, Content: %s, KnownTo: %v",
|
||||||
originalMsg.Role, originalMsg.Content, originalMsg.KnownTo)
|
originalMsg.Role, originalMsg.Content, originalMsg.KnownTo)
|
||||||
|
|
||||||
// Simulate what happens when the message gets copied during processing
|
// Simulate what happens when the message gets copied during processing
|
||||||
copiedMsg := originalMsg.Copy()
|
copiedMsg := originalMsg.Copy()
|
||||||
|
|
||||||
t.Logf("Copied message - Role: %s, Content: %s, KnownTo: %v",
|
t.Logf("Copied message - Role: %s, Content: %s, KnownTo: %v",
|
||||||
copiedMsg.Role, copiedMsg.Content, copiedMsg.KnownTo)
|
copiedMsg.Role, copiedMsg.Content, copiedMsg.KnownTo)
|
||||||
|
|
||||||
// Check if KnownTo field survived the copy
|
// Check if KnownTo field survived the copy
|
||||||
if len(copiedMsg.KnownTo) == 0 {
|
if len(copiedMsg.KnownTo) == 0 {
|
||||||
t.Error("ERROR: KnownTo field was lost during copy!")
|
t.Error("ERROR: KnownTo field was lost during copy!")
|
||||||
} else {
|
} else {
|
||||||
t.Log("SUCCESS: KnownTo field was preserved during copy!")
|
t.Log("SUCCESS: KnownTo field was preserved during copy!")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the content is the same
|
// Verify the content is the same
|
||||||
if copiedMsg.Content != originalMsg.Content {
|
if copiedMsg.Content != originalMsg.Content {
|
||||||
t.Errorf("Content was changed during copy: got %s, want %s", copiedMsg.Content, originalMsg.Content)
|
t.Errorf("Content was changed during copy: got %s, want %s", copiedMsg.Content, originalMsg.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the KnownTo slice is properly copied
|
// Verify the KnownTo slice is properly copied
|
||||||
if !reflect.DeepEqual(copiedMsg.KnownTo, originalMsg.KnownTo) {
|
if !reflect.DeepEqual(copiedMsg.KnownTo, originalMsg.KnownTo) {
|
||||||
t.Errorf("KnownTo was not properly copied: got %v, want %v", copiedMsg.KnownTo, originalMsg.KnownTo)
|
t.Errorf("KnownTo was not properly copied: got %v, want %v", copiedMsg.KnownTo, originalMsg.KnownTo)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ OpenRouterChatAPI = "https://openrouter.ai/api/v1/chat/completions"
|
|||||||
# OpenRouterToken = ""
|
# OpenRouterToken = ""
|
||||||
# embeddings
|
# embeddings
|
||||||
EmbedURL = "http://localhost:8082/v1/embeddings"
|
EmbedURL = "http://localhost:8082/v1/embeddings"
|
||||||
HFToken = false
|
HFToken = ""
|
||||||
#
|
#
|
||||||
ShowSys = true
|
ShowSys = true
|
||||||
LogFile = "log.txt"
|
LogFile = "log.txt"
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ type Config struct {
|
|||||||
UserRole string `toml:"UserRole"`
|
UserRole string `toml:"UserRole"`
|
||||||
ToolRole string `toml:"ToolRole"`
|
ToolRole string `toml:"ToolRole"`
|
||||||
ToolUse bool `toml:"ToolUse"`
|
ToolUse bool `toml:"ToolUse"`
|
||||||
ThinkUse bool `toml:"ThinkUse"`
|
|
||||||
StripThinkingFromAPI bool `toml:"StripThinkingFromAPI"`
|
StripThinkingFromAPI bool `toml:"StripThinkingFromAPI"`
|
||||||
ReasoningEffort string `toml:"ReasoningEffort"`
|
ReasoningEffort string `toml:"ReasoningEffort"`
|
||||||
AssistantRole string `toml:"AssistantRole"`
|
AssistantRole string `toml:"AssistantRole"`
|
||||||
@@ -125,6 +124,9 @@ func LoadConfig(fn string) (*Config, error) {
|
|||||||
if config.CompletionAPI != "" {
|
if config.CompletionAPI != "" {
|
||||||
config.ApiLinks = append(config.ApiLinks, config.CompletionAPI)
|
config.ApiLinks = append(config.ApiLinks, config.CompletionAPI)
|
||||||
}
|
}
|
||||||
|
if config.RAGDir == "" {
|
||||||
|
config.RAGDir = "ragimport"
|
||||||
|
}
|
||||||
// if any value is empty fill with default
|
// if any value is empty fill with default
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -165,9 +165,6 @@ Those could be switched in program, but also bould be setup in config.
|
|||||||
#### ToolUse
|
#### ToolUse
|
||||||
- Enable or disable explanation of tools to llm, so it could use them.
|
- Enable or disable explanation of tools to llm, so it could use them.
|
||||||
|
|
||||||
#### ThinkUse
|
|
||||||
- Enable or disable insertion of JsonSerializerToken at the beggining of llm resp.
|
|
||||||
|
|
||||||
### StripThinkingFromAPI (`true`)
|
### StripThinkingFromAPI (`true`)
|
||||||
- Strip thinking blocks from messages before sending to LLM. Keeps them in chat history for local viewing but reduces token usage in API calls.
|
- Strip thinking blocks from messages before sending to LLM. Keeps them in chat history for local viewing but reduces token usage in API calls.
|
||||||
|
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -6,17 +6,19 @@ require (
|
|||||||
github.com/BurntSushi/toml v1.5.0
|
github.com/BurntSushi/toml v1.5.0
|
||||||
github.com/GrailFinder/google-translate-tts v0.1.3
|
github.com/GrailFinder/google-translate-tts v0.1.3
|
||||||
github.com/GrailFinder/searchagent v0.2.0
|
github.com/GrailFinder/searchagent v0.2.0
|
||||||
|
github.com/PuerkitoBio/goquery v1.11.0
|
||||||
github.com/gdamore/tcell/v2 v2.13.2
|
github.com/gdamore/tcell/v2 v2.13.2
|
||||||
github.com/glebarez/go-sqlite v1.22.0
|
github.com/glebarez/go-sqlite v1.22.0
|
||||||
github.com/gopxl/beep/v2 v2.1.1
|
github.com/gopxl/beep/v2 v2.1.1
|
||||||
github.com/gordonklaus/portaudio v0.0.0-20250206071425-98a94950218b
|
github.com/gordonklaus/portaudio v0.0.0-20250206071425-98a94950218b
|
||||||
github.com/jmoiron/sqlx v1.4.0
|
github.com/jmoiron/sqlx v1.4.0
|
||||||
|
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728
|
||||||
github.com/neurosnap/sentences v1.1.2
|
github.com/neurosnap/sentences v1.1.2
|
||||||
github.com/rivo/tview v0.42.0
|
github.com/rivo/tview v0.42.0
|
||||||
|
github.com/yuin/goldmark v1.4.13
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/PuerkitoBio/goquery v1.11.0 // indirect
|
|
||||||
github.com/andybalholm/cascadia v1.3.3 // indirect
|
github.com/andybalholm/cascadia v1.3.3 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/ebitengine/oto/v3 v3.4.0 // indirect
|
github.com/ebitengine/oto/v3 v3.4.0 // indirect
|
||||||
|
|||||||
3
go.sum
3
go.sum
@@ -43,6 +43,8 @@ github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs
|
|||||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
|
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
|
||||||
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
|
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
|
||||||
|
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+ExRDqGQltzXqN/xypdKP86niVn8=
|
||||||
|
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
|
||||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||||
@@ -67,6 +69,7 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
|||||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
|
|||||||
12
helpfuncs.go
12
helpfuncs.go
@@ -32,7 +32,6 @@ func startModelColorUpdater() {
|
|||||||
|
|
||||||
// Initial check
|
// Initial check
|
||||||
updateCachedModelColor()
|
updateCachedModelColor()
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
updateCachedModelColor()
|
updateCachedModelColor()
|
||||||
}
|
}
|
||||||
@@ -75,15 +74,16 @@ func stripThinkingFromMsg(msg *models.RoleMsg) *models.RoleMsg {
|
|||||||
if !cfg.StripThinkingFromAPI {
|
if !cfg.StripThinkingFromAPI {
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
// Skip user, tool, and system messages - they might contain thinking examples
|
// Skip user, tool, they might contain thinking and system messages - examples
|
||||||
if msg.Role == cfg.UserRole || msg.Role == cfg.ToolRole || msg.Role == "system" {
|
if msg.Role == cfg.UserRole || msg.Role == cfg.ToolRole || msg.Role == "system" {
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
// Strip thinking from assistant messages
|
// Strip thinking from assistant messages
|
||||||
if thinkRE.MatchString(msg.Content) {
|
msgText := msg.GetText()
|
||||||
msg.Content = thinkRE.ReplaceAllString(msg.Content, "")
|
if thinkRE.MatchString(msgText) {
|
||||||
// Clean up any double newlines that might result
|
cleanedText := thinkRE.ReplaceAllString(msgText, "")
|
||||||
msg.Content = strings.TrimSpace(msg.Content)
|
cleanedText = strings.TrimSpace(cleanedText)
|
||||||
|
msg.SetText(cleanedText)
|
||||||
}
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|||||||
110
llm.go
110
llm.go
@@ -11,7 +11,6 @@ import (
|
|||||||
|
|
||||||
var imageAttachmentPath string // Global variable to track image attachment for next message
|
var imageAttachmentPath string // Global variable to track image attachment for next message
|
||||||
var lastImg string // for ctrl+j
|
var lastImg string // for ctrl+j
|
||||||
var RAGMsg = "Retrieved context for user's query:\n"
|
|
||||||
|
|
||||||
// containsToolSysMsg checks if the toolSysMsg already exists in the chat body
|
// containsToolSysMsg checks if the toolSysMsg already exists in the chat body
|
||||||
func containsToolSysMsg() bool {
|
func containsToolSysMsg() bool {
|
||||||
@@ -142,22 +141,6 @@ func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, erro
|
|||||||
newMsg = *processMessageTag(&newMsg)
|
newMsg = *processMessageTag(&newMsg)
|
||||||
chatBody.Messages = append(chatBody.Messages, newMsg)
|
chatBody.Messages = append(chatBody.Messages, newMsg)
|
||||||
}
|
}
|
||||||
// if rag - add as system message to avoid conflicts with tool usage
|
|
||||||
if !resume && cfg.RAGEnabled {
|
|
||||||
um := chatBody.Messages[len(chatBody.Messages)-1].Content
|
|
||||||
logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
|
|
||||||
ragResp, err := chatRagUse(um)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to form a rag msg", "error", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
logger.Debug("RAG response received", "response_len", len(ragResp),
|
|
||||||
"response_preview", ragResp[:min(len(ragResp), 100)])
|
|
||||||
// Use system role for RAG context to avoid conflicts with tool usage
|
|
||||||
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
|
|
||||||
chatBody.Messages = append(chatBody.Messages, ragMsg)
|
|
||||||
logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
|
|
||||||
}
|
|
||||||
// sending description of the tools and how to use them
|
// sending description of the tools and how to use them
|
||||||
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
|
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
|
||||||
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
|
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
|
||||||
@@ -184,9 +167,6 @@ func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, erro
|
|||||||
botMsgStart := "\n" + botPersona + ":\n"
|
botMsgStart := "\n" + botPersona + ":\n"
|
||||||
prompt += botMsgStart
|
prompt += botMsgStart
|
||||||
}
|
}
|
||||||
if cfg.ThinkUse && !cfg.ToolUse {
|
|
||||||
prompt += "<think>"
|
|
||||||
}
|
|
||||||
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
|
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
|
||||||
"msg", msg, "resume", resume, "prompt", prompt, "multimodal_data_count", len(multimodalData))
|
"msg", msg, "resume", resume, "prompt", prompt, "multimodal_data_count", len(multimodalData))
|
||||||
payload := models.NewLCPReq(prompt, chatBody.Model, multimodalData,
|
payload := models.NewLCPReq(prompt, chatBody.Model, multimodalData,
|
||||||
@@ -236,13 +216,11 @@ func (op LCPChat) ParseChunk(data []byte) (*models.TextChunk, error) {
|
|||||||
logger.Warn("LCPChat ParseChunk: no choices in response", "data", string(data))
|
logger.Warn("LCPChat ParseChunk: no choices in response", "data", string(data))
|
||||||
return &models.TextChunk{Finished: true}, nil
|
return &models.TextChunk{Finished: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
lastChoice := llmchunk.Choices[len(llmchunk.Choices)-1]
|
lastChoice := llmchunk.Choices[len(llmchunk.Choices)-1]
|
||||||
resp := &models.TextChunk{
|
resp := &models.TextChunk{
|
||||||
Chunk: lastChoice.Delta.Content,
|
Chunk: lastChoice.Delta.Content,
|
||||||
Reasoning: lastChoice.Delta.ReasoningContent,
|
Reasoning: lastChoice.Delta.ReasoningContent,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for tool calls in all choices, not just the last one
|
// Check for tool calls in all choices, not just the last one
|
||||||
for _, choice := range llmchunk.Choices {
|
for _, choice := range llmchunk.Choices {
|
||||||
if len(choice.Delta.ToolCalls) > 0 {
|
if len(choice.Delta.ToolCalls) > 0 {
|
||||||
@@ -257,7 +235,6 @@ func (op LCPChat) ParseChunk(data []byte) (*models.TextChunk, error) {
|
|||||||
break // Process only the first tool call
|
break // Process only the first tool call
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastChoice.FinishReason == "stop" {
|
if lastChoice.FinishReason == "stop" {
|
||||||
if resp.Chunk != "" {
|
if resp.Chunk != "" {
|
||||||
logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
|
logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
|
||||||
@@ -304,23 +281,6 @@ func (op LCPChat) FormMsg(msg, role string, resume bool) (io.Reader, error) {
|
|||||||
logger.Debug("LCPChat FormMsg: added message to chatBody", "role", newMsg.Role,
|
logger.Debug("LCPChat FormMsg: added message to chatBody", "role", newMsg.Role,
|
||||||
"content_len", len(newMsg.Content), "message_count_after_add", len(chatBody.Messages))
|
"content_len", len(newMsg.Content), "message_count_after_add", len(chatBody.Messages))
|
||||||
}
|
}
|
||||||
// if rag - add as system message to avoid conflicts with tool usage
|
|
||||||
if !resume && cfg.RAGEnabled {
|
|
||||||
um := chatBody.Messages[len(chatBody.Messages)-1].Content
|
|
||||||
logger.Debug("LCPChat: RAG is enabled, preparing RAG context", "user_message", um)
|
|
||||||
ragResp, err := chatRagUse(um)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("LCPChat: failed to form a rag msg", "error", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
logger.Debug("LCPChat: RAG response received",
|
|
||||||
"response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)])
|
|
||||||
// Use system role for RAG context to avoid conflicts with tool usage
|
|
||||||
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
|
|
||||||
chatBody.Messages = append(chatBody.Messages, ragMsg)
|
|
||||||
logger.Debug("LCPChat: RAG message added to chat body", "role", ragMsg.Role,
|
|
||||||
"rag_content_len", len(ragMsg.Content), "message_count_after_rag", len(chatBody.Messages))
|
|
||||||
}
|
|
||||||
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
|
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
|
||||||
// openai /v1/chat does not support custom roles; needs to be user, assistant, system
|
// openai /v1/chat does not support custom roles; needs to be user, assistant, system
|
||||||
// Add persona suffix to the last user message to indicate who the assistant should reply as
|
// Add persona suffix to the last user message to indicate who the assistant should reply as
|
||||||
@@ -392,22 +352,6 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
|
|||||||
newMsg = *processMessageTag(&newMsg)
|
newMsg = *processMessageTag(&newMsg)
|
||||||
chatBody.Messages = append(chatBody.Messages, newMsg)
|
chatBody.Messages = append(chatBody.Messages, newMsg)
|
||||||
}
|
}
|
||||||
// if rag - add as system message to avoid conflicts with tool usage
|
|
||||||
if !resume && cfg.RAGEnabled {
|
|
||||||
um := chatBody.Messages[len(chatBody.Messages)-1].Content
|
|
||||||
logger.Debug("DeepSeekerCompletion: RAG is enabled, preparing RAG context", "user_message", um)
|
|
||||||
ragResp, err := chatRagUse(um)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("DeepSeekerCompletion: failed to form a rag msg", "error", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
logger.Debug("DeepSeekerCompletion: RAG response received",
|
|
||||||
"response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)])
|
|
||||||
// Use system role for RAG context to avoid conflicts with tool usage
|
|
||||||
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
|
|
||||||
chatBody.Messages = append(chatBody.Messages, ragMsg)
|
|
||||||
logger.Debug("DeepSeekerCompletion: RAG message added to chat body", "message_count", len(chatBody.Messages))
|
|
||||||
}
|
|
||||||
// sending description of the tools and how to use them
|
// sending description of the tools and how to use them
|
||||||
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
|
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
|
||||||
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
|
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
|
||||||
@@ -423,9 +367,6 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
|
|||||||
botMsgStart := "\n" + botPersona + ":\n"
|
botMsgStart := "\n" + botPersona + ":\n"
|
||||||
prompt += botMsgStart
|
prompt += botMsgStart
|
||||||
}
|
}
|
||||||
if cfg.ThinkUse && !cfg.ToolUse {
|
|
||||||
prompt += "<think>"
|
|
||||||
}
|
|
||||||
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
|
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
|
||||||
"msg", msg, "resume", resume, "prompt", prompt)
|
"msg", msg, "resume", resume, "prompt", prompt)
|
||||||
payload := models.NewDSCompletionReq(prompt, chatBody.Model,
|
payload := models.NewDSCompletionReq(prompt, chatBody.Model,
|
||||||
@@ -480,22 +421,6 @@ func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
|
|||||||
newMsg = *processMessageTag(&newMsg)
|
newMsg = *processMessageTag(&newMsg)
|
||||||
chatBody.Messages = append(chatBody.Messages, newMsg)
|
chatBody.Messages = append(chatBody.Messages, newMsg)
|
||||||
}
|
}
|
||||||
// if rag - add as system message to avoid conflicts with tool usage
|
|
||||||
if !resume && cfg.RAGEnabled {
|
|
||||||
um := chatBody.Messages[len(chatBody.Messages)-1].Content
|
|
||||||
logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
|
|
||||||
ragResp, err := chatRagUse(um)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to form a rag msg", "error", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
logger.Debug("RAG response received", "response_len", len(ragResp),
|
|
||||||
"response_preview", ragResp[:min(len(ragResp), 100)])
|
|
||||||
// Use system role for RAG context to avoid conflicts with tool usage
|
|
||||||
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
|
|
||||||
chatBody.Messages = append(chatBody.Messages, ragMsg)
|
|
||||||
logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
|
|
||||||
}
|
|
||||||
// Create copy of chat body with standardized user role
|
// Create copy of chat body with standardized user role
|
||||||
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
|
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
|
||||||
// Add persona suffix to the last user message to indicate who the assistant should reply as
|
// Add persona suffix to the last user message to indicate who the assistant should reply as
|
||||||
@@ -558,22 +483,6 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader
|
|||||||
newMsg = *processMessageTag(&newMsg)
|
newMsg = *processMessageTag(&newMsg)
|
||||||
chatBody.Messages = append(chatBody.Messages, newMsg)
|
chatBody.Messages = append(chatBody.Messages, newMsg)
|
||||||
}
|
}
|
||||||
// if rag - add as system message to avoid conflicts with tool usage
|
|
||||||
if !resume && cfg.RAGEnabled {
|
|
||||||
um := chatBody.Messages[len(chatBody.Messages)-1].Content
|
|
||||||
logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
|
|
||||||
ragResp, err := chatRagUse(um)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to form a rag msg", "error", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
logger.Debug("RAG response received", "response_len",
|
|
||||||
len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)])
|
|
||||||
// Use system role for RAG context to avoid conflicts with tool usage
|
|
||||||
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
|
|
||||||
chatBody.Messages = append(chatBody.Messages, ragMsg)
|
|
||||||
logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
|
|
||||||
}
|
|
||||||
// sending description of the tools and how to use them
|
// sending description of the tools and how to use them
|
||||||
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
|
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
|
||||||
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
|
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
|
||||||
@@ -589,9 +498,6 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader
|
|||||||
botMsgStart := "\n" + botPersona + ":\n"
|
botMsgStart := "\n" + botPersona + ":\n"
|
||||||
prompt += botMsgStart
|
prompt += botMsgStart
|
||||||
}
|
}
|
||||||
if cfg.ThinkUse && !cfg.ToolUse {
|
|
||||||
prompt += "<think>"
|
|
||||||
}
|
|
||||||
stopSlice := chatBody.MakeStopSliceExcluding("", listChatRoles())
|
stopSlice := chatBody.MakeStopSliceExcluding("", listChatRoles())
|
||||||
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
|
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
|
||||||
"msg", msg, "resume", resume, "prompt", prompt, "stop_strings", stopSlice)
|
"msg", msg, "resume", resume, "prompt", prompt, "stop_strings", stopSlice)
|
||||||
@@ -679,22 +585,6 @@ func (or OpenRouterChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
|
|||||||
newMsg = *processMessageTag(&newMsg)
|
newMsg = *processMessageTag(&newMsg)
|
||||||
chatBody.Messages = append(chatBody.Messages, newMsg)
|
chatBody.Messages = append(chatBody.Messages, newMsg)
|
||||||
}
|
}
|
||||||
// if rag - add as system message to avoid conflicts with tool usage
|
|
||||||
if !resume && cfg.RAGEnabled {
|
|
||||||
um := chatBody.Messages[len(chatBody.Messages)-1].Content
|
|
||||||
logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
|
|
||||||
ragResp, err := chatRagUse(um)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to form a rag msg", "error", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
logger.Debug("RAG response received", "response_len", len(ragResp),
|
|
||||||
"response_preview", ragResp[:min(len(ragResp), 100)])
|
|
||||||
// Use system role for RAG context to avoid conflicts with tool usage
|
|
||||||
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
|
|
||||||
chatBody.Messages = append(chatBody.Messages, ragMsg)
|
|
||||||
logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
|
|
||||||
}
|
|
||||||
// Create copy of chat body with standardized user role
|
// Create copy of chat body with standardized user role
|
||||||
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
|
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
|
||||||
// Add persona suffix to the last user message to indicate who the assistant should reply as
|
// Add persona suffix to the last user message to indicate who the assistant should reply as
|
||||||
|
|||||||
1
main.go
1
main.go
@@ -7,6 +7,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
boolColors = map[bool]string{true: "green", false: "red"}
|
boolColors = map[bool]string{true: "green", false: "red"}
|
||||||
botRespMode = false
|
botRespMode = false
|
||||||
|
toolRunningMode = false
|
||||||
editMode = false
|
editMode = false
|
||||||
roleEditMode = false
|
roleEditMode = false
|
||||||
injectRole = true
|
injectRole = true
|
||||||
|
|||||||
@@ -241,8 +241,7 @@ func (m *RoleMsg) ToText(i int) string {
|
|||||||
}
|
}
|
||||||
finalContent.WriteString(contentStr)
|
finalContent.WriteString(contentStr)
|
||||||
if m.Stats != nil {
|
if m.Stats != nil {
|
||||||
finalContent.WriteString(fmt.Sprintf("\n[gray::i][%d tok, %.1fs, %.1f t/s][-:-:-]",
|
fmt.Fprintf(&finalContent, "\n[gray::i][%d tok, %.1fs, %.1f t/s][-:-:-]", m.Stats.Tokens, m.Stats.Duration, m.Stats.TokensPerSec)
|
||||||
m.Stats.Tokens, m.Stats.Duration, m.Stats.TokensPerSec))
|
|
||||||
}
|
}
|
||||||
textMsg := fmt.Sprintf("[-:-:b]%s[-:-:-]\n%s\n", icon, finalContent.String())
|
textMsg := fmt.Sprintf("[-:-:b]%s[-:-:-]\n%s\n", icon, finalContent.String())
|
||||||
return strings.ReplaceAll(textMsg, "\n\n", "\n")
|
return strings.ReplaceAll(textMsg, "\n\n", "\n")
|
||||||
@@ -330,6 +329,66 @@ func (m *RoleMsg) Copy() RoleMsg {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetText returns the text content of the message, handling both
|
||||||
|
// simple Content and multimodal ContentParts formats.
|
||||||
|
func (m *RoleMsg) GetText() string {
|
||||||
|
if !m.hasContentParts {
|
||||||
|
return m.Content
|
||||||
|
}
|
||||||
|
var textParts []string
|
||||||
|
for _, part := range m.ContentParts {
|
||||||
|
switch p := part.(type) {
|
||||||
|
case TextContentPart:
|
||||||
|
if p.Type == "text" {
|
||||||
|
textParts = append(textParts, p.Text)
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
if partType, exists := p["type"]; exists {
|
||||||
|
if partType == "text" {
|
||||||
|
if textVal, textExists := p["text"]; textExists {
|
||||||
|
if textStr, isStr := textVal.(string); isStr {
|
||||||
|
textParts = append(textParts, textStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(textParts, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetText updates the text content of the message. If the message has
|
||||||
|
// ContentParts (multimodal), it updates the text parts while preserving
|
||||||
|
// images. If not, it sets the simple Content field.
|
||||||
|
func (m *RoleMsg) SetText(text string) {
|
||||||
|
if !m.hasContentParts {
|
||||||
|
m.Content = text
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var newParts []any
|
||||||
|
for _, part := range m.ContentParts {
|
||||||
|
switch p := part.(type) {
|
||||||
|
case TextContentPart:
|
||||||
|
if p.Type == "text" {
|
||||||
|
p.Text = text
|
||||||
|
newParts = append(newParts, p)
|
||||||
|
} else {
|
||||||
|
newParts = append(newParts, p)
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
if partType, exists := p["type"]; exists && partType == "text" {
|
||||||
|
p["text"] = text
|
||||||
|
newParts = append(newParts, p)
|
||||||
|
} else {
|
||||||
|
newParts = append(newParts, p)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
newParts = append(newParts, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.ContentParts = newParts
|
||||||
|
}
|
||||||
|
|
||||||
// AddTextPart adds a text content part to the message
|
// AddTextPart adds a text content part to the message
|
||||||
func (m *RoleMsg) AddTextPart(text string) {
|
func (m *RoleMsg) AddTextPart(text string) {
|
||||||
if !m.hasContentParts {
|
if !m.hasContentParts {
|
||||||
@@ -341,7 +400,6 @@ func (m *RoleMsg) AddTextPart(text string) {
|
|||||||
}
|
}
|
||||||
m.hasContentParts = true
|
m.hasContentParts = true
|
||||||
}
|
}
|
||||||
|
|
||||||
textPart := TextContentPart{Type: "text", Text: text}
|
textPart := TextContentPart{Type: "text", Text: text}
|
||||||
m.ContentParts = append(m.ContentParts, textPart)
|
m.ContentParts = append(m.ContentParts, textPart)
|
||||||
}
|
}
|
||||||
@@ -357,7 +415,6 @@ func (m *RoleMsg) AddImagePart(imageURL, imagePath string) {
|
|||||||
}
|
}
|
||||||
m.hasContentParts = true
|
m.hasContentParts = true
|
||||||
}
|
}
|
||||||
|
|
||||||
imagePart := ImageContentPart{
|
imagePart := ImageContentPart{
|
||||||
Type: "image_url",
|
Type: "image_url",
|
||||||
Path: imagePath, // Store the original file path
|
Path: imagePath, // Store the original file path
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRoleMsgToTextWithImages(t *testing.T) {
|
func TestRoleMsgToTextWithImages(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -92,7 +90,6 @@ func TestRoleMsgToTextWithImages(t *testing.T) {
|
|||||||
expected: "[orange::i][image: /old/path/photo.jpg][-:-:-]",
|
expected: "[orange::i][image: /old/path/photo.jpg][-:-:-]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := tt.msg.ToText(tt.index)
|
result := tt.msg.ToText(tt.index)
|
||||||
@@ -110,12 +107,10 @@ func TestRoleMsgToTextWithImages(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractDisplayPath(t *testing.T) {
|
func TestExtractDisplayPath(t *testing.T) {
|
||||||
// Save original base dir
|
// Save original base dir
|
||||||
originalBaseDir := imageBaseDir
|
originalBaseDir := imageBaseDir
|
||||||
defer func() { imageBaseDir = originalBaseDir }()
|
defer func() { imageBaseDir = originalBaseDir }()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
baseDir string
|
baseDir string
|
||||||
@@ -153,7 +148,6 @@ func TestExtractDisplayPath(t *testing.T) {
|
|||||||
expected: "..._that_exceeds_sixty_characters_limit_yes_it_is_very_long.jpg",
|
expected: "..._that_exceeds_sixty_characters_limit_yes_it_is_very_long.jpg",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
imageBaseDir = tt.baseDir
|
imageBaseDir = tt.baseDir
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ func TestORModelsListModels(t *testing.T) {
|
|||||||
t.Errorf("expected 4 total models, got %d", len(allModels))
|
t.Errorf("expected 4 total models, got %d", len(allModels))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("integration with or_models.json", func(t *testing.T) {
|
t.Run("integration with or_models.json", func(t *testing.T) {
|
||||||
// Attempt to load the real data file from the project root
|
// Attempt to load the real data file from the project root
|
||||||
path := filepath.Join("..", "or_models.json")
|
path := filepath.Join("..", "or_models.json")
|
||||||
|
|||||||
29
popups.go
29
popups.go
@@ -51,7 +51,7 @@ func showModelSelectionPopup() {
|
|||||||
// Find the current model index to set as selected
|
// Find the current model index to set as selected
|
||||||
currentModelIndex := -1
|
currentModelIndex := -1
|
||||||
for i, model := range modelList {
|
for i, model := range modelList {
|
||||||
if model == chatBody.Model {
|
if strings.TrimPrefix(model, "(loaded) ") == chatBody.Model {
|
||||||
currentModelIndex = i
|
currentModelIndex = i
|
||||||
}
|
}
|
||||||
modelListWidget.AddItem(model, "", 0, nil)
|
modelListWidget.AddItem(model, "", 0, nil)
|
||||||
@@ -65,16 +65,19 @@ func showModelSelectionPopup() {
|
|||||||
chatBody.Model = modelName
|
chatBody.Model = modelName
|
||||||
cfg.CurrentModel = chatBody.Model
|
cfg.CurrentModel = chatBody.Model
|
||||||
pages.RemovePage("modelSelectionPopup")
|
pages.RemovePage("modelSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
updateCachedModelColor()
|
updateCachedModelColor()
|
||||||
updateStatusLine()
|
updateStatusLine()
|
||||||
})
|
})
|
||||||
modelListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
modelListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||||
if event.Key() == tcell.KeyEscape {
|
if event.Key() == tcell.KeyEscape {
|
||||||
pages.RemovePage("modelSelectionPopup")
|
pages.RemovePage("modelSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
||||||
pages.RemovePage("modelSelectionPopup")
|
pages.RemovePage("modelSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return event
|
return event
|
||||||
@@ -160,6 +163,7 @@ func showAPILinkSelectionPopup() {
|
|||||||
cfg.CurrentModel = chatBody.Model
|
cfg.CurrentModel = chatBody.Model
|
||||||
}
|
}
|
||||||
pages.RemovePage("apiLinkSelectionPopup")
|
pages.RemovePage("apiLinkSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
choseChunkParser()
|
choseChunkParser()
|
||||||
updateCachedModelColor()
|
updateCachedModelColor()
|
||||||
updateStatusLine()
|
updateStatusLine()
|
||||||
@@ -167,10 +171,12 @@ func showAPILinkSelectionPopup() {
|
|||||||
apiListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
apiListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||||
if event.Key() == tcell.KeyEscape {
|
if event.Key() == tcell.KeyEscape {
|
||||||
pages.RemovePage("apiLinkSelectionPopup")
|
pages.RemovePage("apiLinkSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
||||||
pages.RemovePage("apiLinkSelectionPopup")
|
pages.RemovePage("apiLinkSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return event
|
return event
|
||||||
@@ -230,6 +236,7 @@ func showUserRoleSelectionPopup() {
|
|||||||
textView.SetText(chatToText(filtered, cfg.ShowSys))
|
textView.SetText(chatToText(filtered, cfg.ShowSys))
|
||||||
// Remove the popup page
|
// Remove the popup page
|
||||||
pages.RemovePage("userRoleSelectionPopup")
|
pages.RemovePage("userRoleSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
// Update the status line to reflect the change
|
// Update the status line to reflect the change
|
||||||
updateStatusLine()
|
updateStatusLine()
|
||||||
colorText()
|
colorText()
|
||||||
@@ -237,10 +244,12 @@ func showUserRoleSelectionPopup() {
|
|||||||
roleListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
roleListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||||
if event.Key() == tcell.KeyEscape {
|
if event.Key() == tcell.KeyEscape {
|
||||||
pages.RemovePage("userRoleSelectionPopup")
|
pages.RemovePage("userRoleSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
||||||
pages.RemovePage("userRoleSelectionPopup")
|
pages.RemovePage("userRoleSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return event
|
return event
|
||||||
@@ -303,16 +312,19 @@ func showBotRoleSelectionPopup() {
|
|||||||
cfg.WriteNextMsgAsCompletionAgent = mainText
|
cfg.WriteNextMsgAsCompletionAgent = mainText
|
||||||
// Remove the popup page
|
// Remove the popup page
|
||||||
pages.RemovePage("botRoleSelectionPopup")
|
pages.RemovePage("botRoleSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
// Update the status line to reflect the change
|
// Update the status line to reflect the change
|
||||||
updateStatusLine()
|
updateStatusLine()
|
||||||
})
|
})
|
||||||
roleListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
roleListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||||
if event.Key() == tcell.KeyEscape {
|
if event.Key() == tcell.KeyEscape {
|
||||||
pages.RemovePage("botRoleSelectionPopup")
|
pages.RemovePage("botRoleSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
||||||
pages.RemovePage("botRoleSelectionPopup")
|
pages.RemovePage("botRoleSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return event
|
return event
|
||||||
@@ -364,14 +376,17 @@ func showFileCompletionPopup(filter string) {
|
|||||||
textArea.SetText(before+mainText, true)
|
textArea.SetText(before+mainText, true)
|
||||||
}
|
}
|
||||||
pages.RemovePage("fileCompletionPopup")
|
pages.RemovePage("fileCompletionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
})
|
})
|
||||||
widget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
widget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||||
if event.Key() == tcell.KeyEscape {
|
if event.Key() == tcell.KeyEscape {
|
||||||
pages.RemovePage("fileCompletionPopup")
|
pages.RemovePage("fileCompletionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
||||||
pages.RemovePage("fileCompletionPopup")
|
pages.RemovePage("fileCompletionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return event
|
return event
|
||||||
@@ -395,38 +410,30 @@ func updateWidgetColors(theme *tview.Theme) {
|
|||||||
fgColor := theme.PrimaryTextColor
|
fgColor := theme.PrimaryTextColor
|
||||||
borderColor := theme.BorderColor
|
borderColor := theme.BorderColor
|
||||||
titleColor := theme.TitleColor
|
titleColor := theme.TitleColor
|
||||||
|
|
||||||
textView.SetBackgroundColor(bgColor)
|
textView.SetBackgroundColor(bgColor)
|
||||||
textView.SetTextColor(fgColor)
|
textView.SetTextColor(fgColor)
|
||||||
textView.SetBorderColor(borderColor)
|
textView.SetBorderColor(borderColor)
|
||||||
textView.SetTitleColor(titleColor)
|
textView.SetTitleColor(titleColor)
|
||||||
|
|
||||||
textArea.SetBackgroundColor(bgColor)
|
textArea.SetBackgroundColor(bgColor)
|
||||||
textArea.SetBorderColor(borderColor)
|
textArea.SetBorderColor(borderColor)
|
||||||
textArea.SetTitleColor(titleColor)
|
textArea.SetTitleColor(titleColor)
|
||||||
textArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
textArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
textArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
textArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
// Force textarea refresh by restoring text (SetTextStyle doesn't trigger redraw)
|
|
||||||
textArea.SetText(textArea.GetText(), true)
|
textArea.SetText(textArea.GetText(), true)
|
||||||
|
|
||||||
editArea.SetBackgroundColor(bgColor)
|
editArea.SetBackgroundColor(bgColor)
|
||||||
editArea.SetBorderColor(borderColor)
|
editArea.SetBorderColor(borderColor)
|
||||||
editArea.SetTitleColor(titleColor)
|
editArea.SetTitleColor(titleColor)
|
||||||
editArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
editArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
editArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
editArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
// Force textarea refresh by restoring text (SetTextStyle doesn't trigger redraw)
|
|
||||||
editArea.SetText(editArea.GetText(), true)
|
editArea.SetText(editArea.GetText(), true)
|
||||||
|
|
||||||
statusLineWidget.SetBackgroundColor(bgColor)
|
statusLineWidget.SetBackgroundColor(bgColor)
|
||||||
statusLineWidget.SetTextColor(fgColor)
|
statusLineWidget.SetTextColor(fgColor)
|
||||||
statusLineWidget.SetBorderColor(borderColor)
|
statusLineWidget.SetBorderColor(borderColor)
|
||||||
statusLineWidget.SetTitleColor(titleColor)
|
statusLineWidget.SetTitleColor(titleColor)
|
||||||
|
|
||||||
helpView.SetBackgroundColor(bgColor)
|
helpView.SetBackgroundColor(bgColor)
|
||||||
helpView.SetTextColor(fgColor)
|
helpView.SetTextColor(fgColor)
|
||||||
helpView.SetBorderColor(borderColor)
|
helpView.SetBorderColor(borderColor)
|
||||||
helpView.SetTitleColor(titleColor)
|
helpView.SetTitleColor(titleColor)
|
||||||
|
|
||||||
searchField.SetBackgroundColor(bgColor)
|
searchField.SetBackgroundColor(bgColor)
|
||||||
searchField.SetBorderColor(borderColor)
|
searchField.SetBorderColor(borderColor)
|
||||||
searchField.SetTitleColor(titleColor)
|
searchField.SetTitleColor(titleColor)
|
||||||
@@ -453,7 +460,6 @@ func showColorschemeSelectionPopup() {
|
|||||||
schemeListWidget := tview.NewList().ShowSecondaryText(false).
|
schemeListWidget := tview.NewList().ShowSecondaryText(false).
|
||||||
SetSelectedBackgroundColor(tcell.ColorGray)
|
SetSelectedBackgroundColor(tcell.ColorGray)
|
||||||
schemeListWidget.SetTitle("Select Colorscheme").SetBorder(true)
|
schemeListWidget.SetTitle("Select Colorscheme").SetBorder(true)
|
||||||
|
|
||||||
currentScheme := "default"
|
currentScheme := "default"
|
||||||
for name := range colorschemes {
|
for name := range colorschemes {
|
||||||
if tview.Styles == colorschemes[name] {
|
if tview.Styles == colorschemes[name] {
|
||||||
@@ -484,14 +490,17 @@ func showColorschemeSelectionPopup() {
|
|||||||
}
|
}
|
||||||
// Remove the popup page
|
// Remove the popup page
|
||||||
pages.RemovePage("colorschemeSelectionPopup")
|
pages.RemovePage("colorschemeSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
})
|
})
|
||||||
schemeListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
schemeListWidget.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||||
if event.Key() == tcell.KeyEscape {
|
if event.Key() == tcell.KeyEscape {
|
||||||
pages.RemovePage("colorschemeSelectionPopup")
|
pages.RemovePage("colorschemeSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
||||||
pages.RemovePage("colorschemeSelectionPopup")
|
pages.RemovePage("colorschemeSelectionPopup")
|
||||||
|
app.SetFocus(textArea)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return event
|
return event
|
||||||
|
|||||||
@@ -115,9 +115,6 @@ func makePropsTable(props map[string]float32) *tview.Table {
|
|||||||
row++
|
row++
|
||||||
}
|
}
|
||||||
// Add checkboxes
|
// Add checkboxes
|
||||||
addCheckboxRow("Insert <think> tag (/completion only)", cfg.ThinkUse, func(checked bool) {
|
|
||||||
cfg.ThinkUse = checked
|
|
||||||
})
|
|
||||||
addCheckboxRow("RAG use", cfg.RAGEnabled, func(checked bool) {
|
addCheckboxRow("RAG use", cfg.RAGEnabled, func(checked bool) {
|
||||||
cfg.RAGEnabled = checked
|
cfg.RAGEnabled = checked
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -131,7 +131,6 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
|||||||
}
|
}
|
||||||
embeddings[data.Index] = data.Embedding
|
embeddings[data.Index] = data.Embedding
|
||||||
}
|
}
|
||||||
|
|
||||||
return embeddings, nil
|
return embeddings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
181
rag/extractors.go
Normal file
181
rag/extractors.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package rag
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/zip"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/PuerkitoBio/goquery"
|
||||||
|
"github.com/ledongthuc/pdf"
|
||||||
|
"github.com/yuin/goldmark"
|
||||||
|
"github.com/yuin/goldmark/extension"
|
||||||
|
"github.com/yuin/goldmark/parser"
|
||||||
|
"github.com/yuin/goldmark/renderer/html"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExtractText(fpath string) (string, error) {
|
||||||
|
ext := strings.ToLower(path.Ext(fpath))
|
||||||
|
switch ext {
|
||||||
|
case ".txt":
|
||||||
|
return extractTextFromFile(fpath)
|
||||||
|
case ".md", ".markdown":
|
||||||
|
return extractTextFromMarkdown(fpath)
|
||||||
|
case ".html", ".htm":
|
||||||
|
return extractTextFromHtmlFile(fpath)
|
||||||
|
case ".epub":
|
||||||
|
return extractTextFromEpub(fpath)
|
||||||
|
case ".pdf":
|
||||||
|
return extractTextFromPdf(fpath)
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unsupported file format: %s", ext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTextFromFile(fpath string) (string, error) {
|
||||||
|
data, err := os.ReadFile(fpath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTextFromHtmlFile(fpath string) (string, error) {
|
||||||
|
data, err := os.ReadFile(fpath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return extractTextFromHtmlContent(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// non utf-8 encoding?
|
||||||
|
func extractTextFromHtmlContent(data []byte) (string, error) {
|
||||||
|
doc, err := goquery.NewDocumentFromReader(bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// Remove script and style tags
|
||||||
|
doc.Find("script, style, noscript").Each(func(i int, s *goquery.Selection) {
|
||||||
|
s.Remove()
|
||||||
|
})
|
||||||
|
// Get text and clean it
|
||||||
|
text := doc.Text()
|
||||||
|
// Collapse all whitespace (newlines, tabs, multiple spaces) into single spaces
|
||||||
|
cleaned := strings.Join(strings.Fields(text), " ")
|
||||||
|
return cleaned, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTextFromMarkdown(fpath string) (string, error) {
|
||||||
|
data, err := os.ReadFile(fpath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// Convert markdown to HTML
|
||||||
|
md := goldmark.New(
|
||||||
|
goldmark.WithExtensions(extension.GFM),
|
||||||
|
goldmark.WithParserOptions(parser.WithAutoHeadingID()),
|
||||||
|
goldmark.WithRendererOptions(html.WithUnsafe()), // allow raw HTML if needed
|
||||||
|
)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := md.Convert(data, &buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// Now extract text from the resulting HTML (using goquery or similar)
|
||||||
|
return extractTextFromHtmlContent(buf.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTextFromEpub(fpath string) (string, error) {
|
||||||
|
r, err := zip.OpenReader(fpath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to open epub: %w", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, f := range r.File {
|
||||||
|
ext := strings.ToLower(path.Ext(f.Name))
|
||||||
|
if ext != ".xhtml" && ext != ".html" && ext != ".htm" && ext != ".xml" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip manifest, toc, ncx files - they don't contain book content
|
||||||
|
nameLower := strings.ToLower(f.Name)
|
||||||
|
if strings.Contains(nameLower, "toc") || strings.Contains(nameLower, "nav") ||
|
||||||
|
strings.Contains(nameLower, "manifest") || strings.Contains(nameLower, ".opf") ||
|
||||||
|
strings.HasSuffix(nameLower, ".ncx") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rc, err := f.Open()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if sb.Len() > 0 {
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
sb.WriteString(f.Name)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
|
||||||
|
buf, readErr := io.ReadAll(rc)
|
||||||
|
rc.Close()
|
||||||
|
if readErr == nil {
|
||||||
|
sb.WriteString(stripHTML(string(buf)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sb.Len() == 0 {
|
||||||
|
return "", errors.New("no content extracted from epub")
|
||||||
|
}
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripHTML(html string) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
inTag := false
|
||||||
|
for _, r := range html {
|
||||||
|
switch r {
|
||||||
|
case '<':
|
||||||
|
inTag = true
|
||||||
|
case '>':
|
||||||
|
inTag = false
|
||||||
|
default:
|
||||||
|
if !inTag {
|
||||||
|
sb.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTextFromPdf(fpath string) (string, error) {
|
||||||
|
_, err := exec.LookPath("pdftotext")
|
||||||
|
if err == nil {
|
||||||
|
out, err := exec.Command("pdftotext", "-layout", fpath, "-").Output()
|
||||||
|
if err == nil && len(out) > 0 {
|
||||||
|
return string(out), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return extractTextFromPdfPureGo(fpath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTextFromPdfPureGo(fpath string) (string, error) {
|
||||||
|
df, r, err := pdf.Open(fpath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to open pdf: %w", err)
|
||||||
|
}
|
||||||
|
defer df.Close()
|
||||||
|
textReader, err := r.GetPlainText()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to extract text from pdf: %w", err)
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, err = io.Copy(&buf, textReader)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read pdf text: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
504
rag/rag.go
504
rag/rag.go
@@ -7,8 +7,9 @@ import (
|
|||||||
"gf-lt/models"
|
"gf-lt/models"
|
||||||
"gf-lt/storage"
|
"gf-lt/storage"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
|
||||||
"path"
|
"path"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -23,19 +24,18 @@ var (
|
|||||||
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
type RAG struct {
|
type RAG struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
store storage.FullRepo
|
store storage.FullRepo
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
embedder Embedder
|
embedder Embedder
|
||||||
storage *VectorStorage
|
storage *VectorStorage
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
||||||
// Initialize with API embedder by default, could be configurable later
|
// Initialize with API embedder by default, could be configurable later
|
||||||
embedder := NewAPIEmbedder(l, cfg)
|
embedder := NewAPIEmbedder(l, cfg)
|
||||||
|
|
||||||
rag := &RAG{
|
rag := &RAG{
|
||||||
logger: l,
|
logger: l,
|
||||||
store: s,
|
store: s,
|
||||||
@@ -54,7 +54,9 @@ func wordCounter(sentence string) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) LoadRAG(fpath string) error {
|
func (r *RAG) LoadRAG(fpath string) error {
|
||||||
data, err := os.ReadFile(fpath)
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
fileText, err := ExtractText(fpath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -63,10 +65,7 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
case LongJobStatusCh <- LoadedFileRAGStatus:
|
case LongJobStatusCh <- LoadedFileRAGStatus:
|
||||||
default:
|
default:
|
||||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
|
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
|
||||||
// Channel is full or closed, ignore the message to prevent panic
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fileText := string(data)
|
|
||||||
tokenizer, err := english.NewSentenceTokenizer(nil)
|
tokenizer, err := english.NewSentenceTokenizer(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -76,19 +75,16 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
for i, s := range sentences {
|
for i, s := range sentences {
|
||||||
sents[i] = s.Text
|
sents[i] = s.Text
|
||||||
}
|
}
|
||||||
|
|
||||||
// Group sentences into paragraphs based on word limit
|
// Group sentences into paragraphs based on word limit
|
||||||
paragraphs := []string{}
|
paragraphs := []string{}
|
||||||
par := strings.Builder{}
|
par := strings.Builder{}
|
||||||
for i := 0; i < len(sents); i++ {
|
for i := 0; i < len(sents); i++ {
|
||||||
// Only add sentences that aren't empty
|
|
||||||
if strings.TrimSpace(sents[i]) != "" {
|
if strings.TrimSpace(sents[i]) != "" {
|
||||||
if par.Len() > 0 {
|
if par.Len() > 0 {
|
||||||
par.WriteString(" ") // Add space between sentences
|
par.WriteString(" ")
|
||||||
}
|
}
|
||||||
par.WriteString(sents[i])
|
par.WriteString(sents[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
|
if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
|
||||||
paragraph := strings.TrimSpace(par.String())
|
paragraph := strings.TrimSpace(par.String())
|
||||||
if paragraph != "" {
|
if paragraph != "" {
|
||||||
@@ -97,7 +93,6 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
par.Reset()
|
par.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle any remaining content in the paragraph buffer
|
// Handle any remaining content in the paragraph buffer
|
||||||
if par.Len() > 0 {
|
if par.Len() > 0 {
|
||||||
paragraph := strings.TrimSpace(par.String())
|
paragraph := strings.TrimSpace(par.String())
|
||||||
@@ -105,217 +100,84 @@ func (r *RAG) LoadRAG(fpath string) error {
|
|||||||
paragraphs = append(paragraphs, paragraph)
|
paragraphs = append(paragraphs, paragraph)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adjust batch size if needed
|
// Adjust batch size if needed
|
||||||
if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 {
|
if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 {
|
||||||
r.cfg.RAGBatchSize = len(paragraphs)
|
r.cfg.RAGBatchSize = len(paragraphs)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(paragraphs) == 0 {
|
if len(paragraphs) == 0 {
|
||||||
return errors.New("no valid paragraphs found in file")
|
return errors.New("no valid paragraphs found in file")
|
||||||
}
|
}
|
||||||
|
// Process paragraphs in batches synchronously
|
||||||
var (
|
batchCount := 0
|
||||||
maxChSize = 100
|
for i := 0; i < len(paragraphs); i += r.cfg.RAGBatchSize {
|
||||||
left = 0
|
end := i + r.cfg.RAGBatchSize
|
||||||
right = r.cfg.RAGBatchSize
|
if end > len(paragraphs) {
|
||||||
batchCh = make(chan map[int][]string, maxChSize)
|
end = len(paragraphs)
|
||||||
vectorCh = make(chan []models.VectorRow, maxChSize)
|
|
||||||
errCh = make(chan error, 1)
|
|
||||||
wg = new(sync.WaitGroup)
|
|
||||||
lock = new(sync.Mutex)
|
|
||||||
)
|
|
||||||
|
|
||||||
defer close(errCh)
|
|
||||||
defer close(batchCh)
|
|
||||||
|
|
||||||
// Fill input channel with batches
|
|
||||||
ctn := 0
|
|
||||||
totalParagraphs := len(paragraphs)
|
|
||||||
for {
|
|
||||||
if right > totalParagraphs {
|
|
||||||
batchCh <- map[int][]string{left: paragraphs[left:]}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
batchCh <- map[int][]string{left: paragraphs[left:right]}
|
batch := paragraphs[i:end]
|
||||||
left, right = right, right+r.cfg.RAGBatchSize
|
batchCount++
|
||||||
ctn++
|
// Filter empty paragraphs
|
||||||
|
nonEmptyBatch := make([]string, 0, len(batch))
|
||||||
|
for _, p := range batch {
|
||||||
|
if strings.TrimSpace(p) != "" {
|
||||||
|
nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
|
||||||
}
|
}
|
||||||
|
|
||||||
finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents))
|
|
||||||
r.logger.Debug(finishedBatchesMsg)
|
|
||||||
select {
|
|
||||||
case LongJobStatusCh <- finishedBatchesMsg:
|
|
||||||
default:
|
|
||||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", finishedBatchesMsg)
|
|
||||||
// Channel is full or closed, ignore the message to prevent panic
|
|
||||||
}
|
}
|
||||||
|
if len(nonEmptyBatch) == 0 {
|
||||||
// Start worker goroutines with WaitGroup
|
continue
|
||||||
wg.Add(int(r.cfg.RAGWorkers))
|
|
||||||
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
|
|
||||||
go func(workerID int) {
|
|
||||||
defer wg.Done()
|
|
||||||
r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath))
|
|
||||||
}(w)
|
|
||||||
}
|
}
|
||||||
|
// Embed the batch
|
||||||
// Use a goroutine to close the batchCh when all batches are sent
|
embeddings, err := r.embedder.EmbedSlice(nonEmptyBatch)
|
||||||
go func() {
|
|
||||||
wg.Wait()
|
|
||||||
close(vectorCh) // Close vectorCh when all workers are done
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Check for errors from workers
|
|
||||||
// Use a non-blocking check for errors
|
|
||||||
select {
|
|
||||||
case err := <-errCh:
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("error during RAG processing", "error", err)
|
r.logger.Error("failed to embed batch", "error", err, "batch", batchCount)
|
||||||
|
select {
|
||||||
|
case LongJobStatusCh <- ErrRAGStatus:
|
||||||
|
default:
|
||||||
|
r.logger.Warn("LongJobStatusCh channel full, dropping message")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to embed batch %d: %w", batchCount, err)
|
||||||
|
}
|
||||||
|
if len(embeddings) != len(nonEmptyBatch) {
|
||||||
|
err := errors.New("embedding count mismatch")
|
||||||
|
r.logger.Error("embedding mismatch", "expected", len(nonEmptyBatch), "got", len(embeddings))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
// Write vectors to storage
|
||||||
// No immediate error, continue
|
filename := path.Base(fpath)
|
||||||
|
for j, text := range nonEmptyBatch {
|
||||||
|
vector := models.VectorRow{
|
||||||
|
Embeddings: embeddings[j],
|
||||||
|
RawText: text,
|
||||||
|
Slug: fmt.Sprintf("%s_%d_%d", filename, batchCount, j),
|
||||||
|
FileName: filename,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write vectors to storage - this will block until vectorCh is closed
|
|
||||||
return r.writeVectors(vectorCh)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
|
|
||||||
for {
|
|
||||||
for batch := range vectorCh {
|
|
||||||
for _, vector := range batch {
|
|
||||||
if err := r.storage.WriteVector(&vector); err != nil {
|
if err := r.storage.WriteVector(&vector); err != nil {
|
||||||
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
|
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
|
||||||
select {
|
select {
|
||||||
case LongJobStatusCh <- ErrRAGStatus:
|
case LongJobStatusCh <- ErrRAGStatus:
|
||||||
default:
|
default:
|
||||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
|
r.logger.Warn("LongJobStatusCh channel full, dropping message")
|
||||||
// Channel is full or closed, ignore the message to prevent panic
|
|
||||||
}
|
}
|
||||||
return err // Stop the entire RAG operation on DB error
|
return fmt.Errorf("failed to write vector: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
|
r.logger.Debug("wrote batch to db", "batch", batchCount, "size", len(nonEmptyBatch))
|
||||||
if len(vectorCh) == 0 {
|
// Send progress status
|
||||||
r.logger.Debug("finished writing vectors")
|
statusMsg := fmt.Sprintf("processed batch %d/%d", batchCount, (len(paragraphs)+r.cfg.RAGBatchSize-1)/r.cfg.RAGBatchSize)
|
||||||
|
select {
|
||||||
|
case LongJobStatusCh <- statusMsg:
|
||||||
|
default:
|
||||||
|
r.logger.Warn("LongJobStatusCh channel full, dropping message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.logger.Debug("finished writing vectors", "batches", batchCount)
|
||||||
select {
|
select {
|
||||||
case LongJobStatusCh <- FinishedRAGStatus:
|
case LongJobStatusCh <- FinishedRAGStatus:
|
||||||
default:
|
default:
|
||||||
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
|
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
|
||||||
// Channel is full or closed, ignore the message to prevent panic
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
|
|
||||||
vectorCh chan<- []models.VectorRow, errCh chan error, filename string) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
// For errCh, make sure we only send if there's actually an error and the channel can accept it
|
|
||||||
if err != nil {
|
|
||||||
select {
|
|
||||||
case errCh <- err:
|
|
||||||
default:
|
|
||||||
// errCh might be full or closed, log but don't panic
|
|
||||||
r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
lock.Lock()
|
|
||||||
if len(inputCh) == 0 {
|
|
||||||
lock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case linesMap := <-inputCh:
|
|
||||||
for leftI, lines := range linesMap {
|
|
||||||
if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil {
|
|
||||||
r.logger.Error("error fetching embeddings", "error", err, "worker", id)
|
|
||||||
lock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lock.Unlock()
|
|
||||||
case err = <-errCh:
|
|
||||||
r.logger.Error("got an error from error channel", "error", err)
|
|
||||||
lock.Unlock()
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
lock.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id)
|
|
||||||
statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
|
|
||||||
select {
|
|
||||||
case LongJobStatusCh <- statusMsg:
|
|
||||||
default:
|
|
||||||
r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
|
|
||||||
// Channel is full or closed, ignore the message to prevent panic
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
|
|
||||||
// Filter out empty lines before sending to embedder
|
|
||||||
nonEmptyLines := make([]string, 0, len(lines))
|
|
||||||
for _, line := range lines {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
if trimmed != "" {
|
|
||||||
nonEmptyLines = append(nonEmptyLines, trimmed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip if no non-empty lines
|
|
||||||
if len(nonEmptyLines) == 0 {
|
|
||||||
// Send empty result but don't error
|
|
||||||
vectorCh <- []models.VectorRow{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
embeddings, err := r.embedder.EmbedSlice(nonEmptyLines)
|
|
||||||
if err != nil {
|
|
||||||
r.logger.Error("failed to embed lines", "err", err.Error())
|
|
||||||
errCh <- err
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(embeddings) == 0 {
|
|
||||||
err := errors.New("no embeddings returned")
|
|
||||||
r.logger.Error("empty embeddings")
|
|
||||||
errCh <- err
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(embeddings) != len(nonEmptyLines) {
|
|
||||||
err := errors.New("mismatch between number of lines and embeddings returned")
|
|
||||||
r.logger.Error("embedding mismatch", "err", err.Error())
|
|
||||||
errCh <- err
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a VectorRow for each line in the batch
|
|
||||||
vectors := make([]models.VectorRow, len(nonEmptyLines))
|
|
||||||
for i, line := range nonEmptyLines {
|
|
||||||
vectors[i] = models.VectorRow{
|
|
||||||
Embeddings: embeddings[i],
|
|
||||||
RawText: line,
|
|
||||||
Slug: fmt.Sprintf("%s_%d", slug, i),
|
|
||||||
FileName: filename,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
vectorCh <- vectors
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
func (r *RAG) LineToVector(line string) ([]float32, error) {
|
||||||
return r.embedder.Embed(line)
|
return r.embedder.Embed(line)
|
||||||
@@ -332,3 +194,259 @@ func (r *RAG) ListLoaded() ([]string, error) {
|
|||||||
func (r *RAG) RemoveFile(filename string) error {
|
func (r *RAG) RemoveFile(filename string) error {
|
||||||
return r.storage.RemoveEmbByFileName(filename)
|
return r.storage.RemoveEmbByFileName(filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
|
||||||
|
importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
|
||||||
|
stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *RAG) RefineQuery(query string) string {
|
||||||
|
original := query
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if len(query) == 0 {
|
||||||
|
return original
|
||||||
|
}
|
||||||
|
if len(query) <= 3 {
|
||||||
|
return original
|
||||||
|
}
|
||||||
|
query = strings.ToLower(query)
|
||||||
|
for _, stopWord := range stopWords {
|
||||||
|
wordPattern := `\b` + stopWord + `\b`
|
||||||
|
re := regexp.MustCompile(wordPattern)
|
||||||
|
query = re.ReplaceAllString(query, "")
|
||||||
|
}
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if len(query) < 5 {
|
||||||
|
return original
|
||||||
|
}
|
||||||
|
if queryRefinementPattern.MatchString(original) {
|
||||||
|
cleaned := queryRefinementPattern.ReplaceAllString(original, "")
|
||||||
|
cleaned = strings.TrimSpace(cleaned)
|
||||||
|
if len(cleaned) >= 5 {
|
||||||
|
return cleaned
|
||||||
|
}
|
||||||
|
}
|
||||||
|
query = r.extractImportantPhrases(query)
|
||||||
|
if len(query) < 5 {
|
||||||
|
return original
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RAG) extractImportantPhrases(query string) string {
|
||||||
|
words := strings.Fields(query)
|
||||||
|
var important []string
|
||||||
|
for _, word := range words {
|
||||||
|
word = strings.Trim(word, ".,!?;:'\"()[]{}")
|
||||||
|
isImportant := false
|
||||||
|
for _, kw := range importantKeywords {
|
||||||
|
if strings.Contains(strings.ToLower(word), kw) {
|
||||||
|
isImportant = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isImportant || len(word) > 3 {
|
||||||
|
important = append(important, word)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(important) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
return strings.Join(important, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RAG) GenerateQueryVariations(query string) []string {
|
||||||
|
variations := []string{query}
|
||||||
|
if len(query) < 5 {
|
||||||
|
return variations
|
||||||
|
}
|
||||||
|
parts := strings.Fields(query)
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return variations
|
||||||
|
}
|
||||||
|
if len(parts) >= 2 {
|
||||||
|
trimmed := strings.Join(parts[:len(parts)-1], " ")
|
||||||
|
if len(trimmed) >= 5 {
|
||||||
|
variations = append(variations, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) >= 2 {
|
||||||
|
trimmed := strings.Join(parts[1:], " ")
|
||||||
|
if len(trimmed) >= 5 {
|
||||||
|
variations = append(variations, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(query, " explanation") {
|
||||||
|
variations = append(variations, query+" explanation")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(query, "what is ") {
|
||||||
|
variations = append(variations, "what is "+query)
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(query, " details") {
|
||||||
|
variations = append(variations, query+" details")
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(query, " summary") {
|
||||||
|
variations = append(variations, query+" summary")
|
||||||
|
}
|
||||||
|
return variations
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow {
|
||||||
|
type scoredResult struct {
|
||||||
|
row models.VectorRow
|
||||||
|
distance float32
|
||||||
|
}
|
||||||
|
scored := make([]scoredResult, 0, len(results))
|
||||||
|
for i := range results {
|
||||||
|
row := results[i]
|
||||||
|
|
||||||
|
score := float32(0)
|
||||||
|
rawTextLower := strings.ToLower(row.RawText)
|
||||||
|
queryLower := strings.ToLower(query)
|
||||||
|
if strings.Contains(rawTextLower, queryLower) {
|
||||||
|
score += 10
|
||||||
|
}
|
||||||
|
queryWords := strings.Fields(queryLower)
|
||||||
|
matchCount := 0
|
||||||
|
for _, word := range queryWords {
|
||||||
|
if len(word) > 2 && strings.Contains(rawTextLower, word) {
|
||||||
|
matchCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(queryWords) > 0 {
|
||||||
|
score += float32(matchCount) / float32(len(queryWords)) * 5
|
||||||
|
}
|
||||||
|
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
|
||||||
|
score += 3
|
||||||
|
}
|
||||||
|
distance := row.Distance - score/100
|
||||||
|
scored = append(scored, scoredResult{row: row, distance: distance})
|
||||||
|
}
|
||||||
|
sort.Slice(scored, func(i, j int) bool {
|
||||||
|
return scored[i].distance < scored[j].distance
|
||||||
|
})
|
||||||
|
unique := make([]models.VectorRow, 0)
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for i := range scored {
|
||||||
|
if !seen[scored[i].row.Slug] {
|
||||||
|
seen[scored[i].row.Slug] = true
|
||||||
|
unique = append(unique, scored[i].row)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(unique) > 10 {
|
||||||
|
unique = unique[:10]
|
||||||
|
}
|
||||||
|
return unique
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
|
||||||
|
if len(results) == 0 {
|
||||||
|
return "No relevant information found in the vector database.", nil
|
||||||
|
}
|
||||||
|
var contextBuilder strings.Builder
|
||||||
|
contextBuilder.WriteString("User Query: ")
|
||||||
|
contextBuilder.WriteString(query)
|
||||||
|
contextBuilder.WriteString("\n\nRetrieved Context:\n")
|
||||||
|
for i, row := range results {
|
||||||
|
fmt.Fprintf(&contextBuilder, "[Source %d: %s]\n", i+1, row.FileName)
|
||||||
|
contextBuilder.WriteString(row.RawText)
|
||||||
|
contextBuilder.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
contextBuilder.WriteString("Instructions: ")
|
||||||
|
contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
|
||||||
|
contextBuilder.WriteString("Extract only the most relevant information. ")
|
||||||
|
contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
|
||||||
|
contextBuilder.WriteString("Cite sources by filename when relevant. ")
|
||||||
|
contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
|
||||||
|
synthesisPrompt := contextBuilder.String()
|
||||||
|
emb, err := r.LineToVector(synthesisPrompt)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Error("failed to embed synthesis prompt", "error", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
embResp := &models.EmbeddingResp{
|
||||||
|
Embedding: emb,
|
||||||
|
Index: 0,
|
||||||
|
}
|
||||||
|
topResults, err := r.SearchEmb(embResp)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Error("failed to search for synthesis context", "error", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
|
||||||
|
return topResults[0].RawText, nil
|
||||||
|
}
|
||||||
|
var finalAnswer strings.Builder
|
||||||
|
finalAnswer.WriteString("Based on the retrieved context:\n\n")
|
||||||
|
for i, row := range results {
|
||||||
|
if i >= 5 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&finalAnswer, "- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))
|
||||||
|
}
|
||||||
|
return finalAnswer.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateString(s string, maxLen int) string {
|
||||||
|
if len(s) <= maxLen {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:maxLen] + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||||
|
refined := r.RefineQuery(query)
|
||||||
|
variations := r.GenerateQueryVariations(refined)
|
||||||
|
allResults := make([]models.VectorRow, 0)
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, q := range variations {
|
||||||
|
emb, err := r.LineToVector(q)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Error("failed to embed query variation", "error", err, "query", q)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
embResp := &models.EmbeddingResp{
|
||||||
|
Embedding: emb,
|
||||||
|
Index: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := r.SearchEmb(embResp)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Error("failed to search embeddings", "error", err, "query", q)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, row := range results {
|
||||||
|
if !seen[row.Slug] {
|
||||||
|
seen[row.Slug] = true
|
||||||
|
allResults = append(allResults, row)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reranked := r.RerankResults(allResults, query)
|
||||||
|
if len(reranked) > limit {
|
||||||
|
reranked = reranked[:limit]
|
||||||
|
}
|
||||||
|
return reranked, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ragInstance *RAG
|
||||||
|
ragOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
|
||||||
|
ragOnce.Do(func() {
|
||||||
|
if c == nil || l == nil || s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ragInstance = New(l, s, c)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetInstance() *RAG {
|
||||||
|
return ragInstance
|
||||||
|
}
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ func NewVectorStorage(logger *slog.Logger, store storage.FullRepo) *VectorStorag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// SerializeVector converts []float32 to binary blob
|
// SerializeVector converts []float32 to binary blob
|
||||||
func SerializeVector(vec []float32) []byte {
|
func SerializeVector(vec []float32) []byte {
|
||||||
buf := make([]byte, len(vec)*4) // 4 bytes per float32
|
buf := make([]byte, len(vec)*4) // 4 bytes per float32
|
||||||
@@ -66,17 +65,14 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
|
|||||||
|
|
||||||
// Serialize the embeddings to binary
|
// Serialize the embeddings to binary
|
||||||
serializedEmbeddings := SerializeVector(row.Embeddings)
|
serializedEmbeddings := SerializeVector(row.Embeddings)
|
||||||
|
|
||||||
query := fmt.Sprintf(
|
query := fmt.Sprintf(
|
||||||
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
|
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
|
||||||
tableName,
|
tableName,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
|
if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
|
||||||
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
|
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,11 +91,9 @@ func (vs *VectorStorage) getTableName(emb []float32) (string, error) {
|
|||||||
4096: true,
|
4096: true,
|
||||||
5120: true,
|
5120: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if supportedSizes[size] {
|
if supportedSizes[size] {
|
||||||
return fmt.Sprintf("embeddings_%d", size), nil
|
return fmt.Sprintf("embeddings_%d", size), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("no table for embedding size of %d", size)
|
return "", fmt.Errorf("no table for embedding size of %d", size)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,9 +120,7 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
vector models.VectorRow
|
vector models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
|
|
||||||
var topResults []SearchResult
|
var topResults []SearchResult
|
||||||
|
|
||||||
// Process vectors one by one to avoid loading everything into memory
|
// Process vectors one by one to avoid loading everything into memory
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
@@ -176,14 +168,12 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
result.vector.Distance = result.distance
|
result.vector.Distance = result.distance
|
||||||
results = append(results, result.vector)
|
results = append(results, result.vector)
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListFiles returns a list of all loaded files
|
// ListFiles returns a list of all loaded files
|
||||||
func (vs *VectorStorage) ListFiles() ([]string, error) {
|
func (vs *VectorStorage) ListFiles() ([]string, error) {
|
||||||
fileLists := make([][]string, 0)
|
fileLists := make([][]string, 0)
|
||||||
|
|
||||||
// Query all supported tables and combine results
|
// Query all supported tables and combine results
|
||||||
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
||||||
for _, size := range embeddingSizes {
|
for _, size := range embeddingSizes {
|
||||||
@@ -219,14 +209,12 @@ func (vs *VectorStorage) ListFiles() ([]string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return allFiles, nil
|
return allFiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveEmbByFileName removes all embeddings associated with a specific filename
|
// RemoveEmbByFileName removes all embeddings associated with a specific filename
|
||||||
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
||||||
var errors []string
|
var errors []string
|
||||||
|
|
||||||
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
||||||
for _, size := range embeddingSizes {
|
for _, size := range embeddingSizes {
|
||||||
table := fmt.Sprintf("embeddings_%d", size)
|
table := fmt.Sprintf("embeddings_%d", size)
|
||||||
@@ -235,11 +223,9 @@ func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
|||||||
errors = append(errors, err.Error())
|
errors = append(errors, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(errors) > 0 {
|
if len(errors) > 0 {
|
||||||
return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; "))
|
return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; "))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,18 +234,15 @@ func cosineSimilarity(a, b []float32) float32 {
|
|||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
var dotProduct, normA, normB float32
|
var dotProduct, normA, normB float32
|
||||||
for i := 0; i < len(a); i++ {
|
for i := 0; i < len(a); i++ {
|
||||||
dotProduct += a[i] * b[i]
|
dotProduct += a[i] * b[i]
|
||||||
normA += a[i] * a[i]
|
normA += a[i] * a[i]
|
||||||
normB += b[i] * b[i]
|
normB += b[i] * b[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
if normA == 0 || normB == 0 {
|
if normA == 0 || normB == 0 {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,4 +258,3 @@ func sqrt(f float32) float32 {
|
|||||||
}
|
}
|
||||||
return guess
|
return guess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -103,7 +103,6 @@ func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
p := ProviderSQL{db: db, logger: logger}
|
p := ProviderSQL{db: db, logger: logger}
|
||||||
|
|
||||||
p.Migrate()
|
p.Migrate()
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,12 +73,9 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
serializedEmbeddings := SerializeVector(row.Embeddings)
|
serializedEmbeddings := SerializeVector(row.Embeddings)
|
||||||
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
|
query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
|
||||||
_, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName)
|
_, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,27 +84,22 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
|
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
|
||||||
rows, err := p.db.Query(querySQL)
|
rows, err := p.db.Query(querySQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
type SearchResult struct {
|
type SearchResult struct {
|
||||||
vector models.VectorRow
|
vector models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
|
|
||||||
var topResults []SearchResult
|
var topResults []SearchResult
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
embeddingsBlob []byte
|
embeddingsBlob []byte
|
||||||
slug, rawText, fileName string
|
slug, rawText, fileName string
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil {
|
if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -152,7 +144,6 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
|||||||
result.vector.Distance = result.distance
|
result.vector.Distance = result.distance
|
||||||
results[i] = result.vector
|
results[i] = result.vector
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,18 +152,15 @@ func cosineSimilarity(a, b []float32) float32 {
|
|||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
var dotProduct, normA, normB float32
|
var dotProduct, normA, normB float32
|
||||||
for i := 0; i < len(a); i++ {
|
for i := 0; i < len(a); i++ {
|
||||||
dotProduct += a[i] * b[i]
|
dotProduct += a[i] * b[i]
|
||||||
normA += a[i] * a[i]
|
normA += a[i] * a[i]
|
||||||
normB += b[i] * b[i]
|
normB += b[i] * b[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
if normA == 0 || normB == 0 {
|
if normA == 0 || normB == 0 {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,13 +217,11 @@ func (p ProviderSQL) ListFiles() ([]string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return allFiles, nil
|
return allFiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
||||||
var errors []string
|
var errors []string
|
||||||
|
|
||||||
tableNames := []string{
|
tableNames := []string{
|
||||||
"embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536",
|
"embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536",
|
||||||
"embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120",
|
"embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120",
|
||||||
@@ -246,10 +232,8 @@ func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
|||||||
errors = append(errors, err.Error())
|
errors = append(errors, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(errors) > 0 {
|
if len(errors) > 0 {
|
||||||
return fmt.Errorf("errors occurred: %v", errors)
|
return fmt.Errorf("errors occurred: %v", errors)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
390
tables.go
390
tables.go
@@ -236,9 +236,59 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// nolint:unused
|
// nolint:unused
|
||||||
func makeRAGTable(fileList []string) *tview.Flex {
|
func formatSize(size int64) string {
|
||||||
actions := []string{"load", "delete"}
|
units := []string{"B", "KB", "MB", "GB", "TB"}
|
||||||
rows, cols := len(fileList), len(actions)+1
|
i := 0
|
||||||
|
s := float64(size)
|
||||||
|
for s >= 1024 && i < len(units)-1 {
|
||||||
|
s /= 1024
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.1f%s", s, units[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
type ragFileInfo struct {
|
||||||
|
name string
|
||||||
|
inRAGDir bool
|
||||||
|
isLoaded bool
|
||||||
|
fullPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex {
|
||||||
|
// Build set of loaded files for quick lookup
|
||||||
|
loadedSet := make(map[string]bool)
|
||||||
|
for _, f := range loadedFiles {
|
||||||
|
loadedSet[f] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build merged list: files from ragdir + orphaned files from DB
|
||||||
|
ragFiles := make([]ragFileInfo, 0, len(fileList)+len(loadedFiles))
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
// Add files from ragdir
|
||||||
|
for _, f := range fileList {
|
||||||
|
ragFiles = append(ragFiles, ragFileInfo{
|
||||||
|
name: f,
|
||||||
|
inRAGDir: true,
|
||||||
|
isLoaded: loadedSet[f],
|
||||||
|
fullPath: path.Join(cfg.RAGDir, f),
|
||||||
|
})
|
||||||
|
seen[f] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add orphaned files (in DB but not in ragdir)
|
||||||
|
for _, f := range loadedFiles {
|
||||||
|
if !seen[f] {
|
||||||
|
ragFiles = append(ragFiles, ragFileInfo{
|
||||||
|
name: f,
|
||||||
|
inRAGDir: false,
|
||||||
|
isLoaded: true,
|
||||||
|
fullPath: "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rows := len(ragFiles)
|
||||||
|
cols := 4 // File Name | Preview | Action | Delete
|
||||||
fileTable := tview.NewTable().
|
fileTable := tview.NewTable().
|
||||||
SetBorders(true)
|
SetBorders(true)
|
||||||
longStatusView := tview.NewTextView()
|
longStatusView := tview.NewTextView()
|
||||||
@@ -252,41 +302,92 @@ func makeRAGTable(fileList []string) *tview.Flex {
|
|||||||
AddItem(fileTable, 0, 60, true)
|
AddItem(fileTable, 0, 60, true)
|
||||||
// Add the exit option as the first row (row 0)
|
// Add the exit option as the first row (row 0)
|
||||||
fileTable.SetCell(0, 0,
|
fileTable.SetCell(0, 0,
|
||||||
tview.NewTableCell("Exit RAG manager").
|
tview.NewTableCell("File Name").
|
||||||
SetTextColor(tcell.ColorWhite).
|
SetTextColor(tcell.ColorWhite).
|
||||||
SetAlign(tview.AlignCenter).
|
SetAlign(tview.AlignCenter).
|
||||||
SetSelectable(false))
|
SetSelectable(false))
|
||||||
fileTable.SetCell(0, 1,
|
fileTable.SetCell(0, 1,
|
||||||
tview.NewTableCell("(Close without action)").
|
tview.NewTableCell("Preview").
|
||||||
SetTextColor(tcell.ColorGray).
|
SetTextColor(tcell.ColorWhite).
|
||||||
SetAlign(tview.AlignCenter).
|
SetAlign(tview.AlignCenter).
|
||||||
SetSelectable(false))
|
SetSelectable(false))
|
||||||
fileTable.SetCell(0, 2,
|
fileTable.SetCell(0, 2,
|
||||||
tview.NewTableCell("exit").
|
tview.NewTableCell("Load/Unload").
|
||||||
SetTextColor(tcell.ColorGray).
|
SetTextColor(tcell.ColorWhite).
|
||||||
SetAlign(tview.AlignCenter))
|
SetAlign(tview.AlignCenter).
|
||||||
|
SetSelectable(false))
|
||||||
|
fileTable.SetCell(0, 3,
|
||||||
|
tview.NewTableCell("Delete").
|
||||||
|
SetTextColor(tcell.ColorWhite).
|
||||||
|
SetAlign(tview.AlignCenter).
|
||||||
|
SetSelectable(false))
|
||||||
// Add the file rows starting from row 1
|
// Add the file rows starting from row 1
|
||||||
for r := 0; r < rows; r++ {
|
for r := 0; r < rows; r++ {
|
||||||
|
f := ragFiles[r]
|
||||||
for c := 0; c < cols; c++ {
|
for c := 0; c < cols; c++ {
|
||||||
color := tcell.ColorWhite
|
color := tcell.ColorWhite
|
||||||
switch {
|
switch c {
|
||||||
case c < 1:
|
case 0:
|
||||||
fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0
|
displayName := f.name
|
||||||
tview.NewTableCell(fileList[r]).
|
if !f.inRAGDir {
|
||||||
|
displayName = f.name + " (orphaned)"
|
||||||
|
}
|
||||||
|
fileTable.SetCell(r+1, c,
|
||||||
|
tview.NewTableCell(displayName).
|
||||||
SetTextColor(color).
|
SetTextColor(color).
|
||||||
SetAlign(tview.AlignCenter).
|
SetAlign(tview.AlignCenter).
|
||||||
SetSelectable(false))
|
SetSelectable(false))
|
||||||
case c == 1: // Action description column - not selectable
|
case 1:
|
||||||
fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0
|
if !f.inRAGDir {
|
||||||
tview.NewTableCell("(Action)").
|
// Orphaned file - no preview available
|
||||||
|
fileTable.SetCell(r+1, c,
|
||||||
|
tview.NewTableCell("not in ragdir").
|
||||||
|
SetTextColor(tcell.ColorYellow).
|
||||||
|
SetAlign(tview.AlignCenter).
|
||||||
|
SetSelectable(false))
|
||||||
|
} else if fi, err := os.Stat(f.fullPath); err == nil {
|
||||||
|
size := fi.Size()
|
||||||
|
modTime := fi.ModTime()
|
||||||
|
preview := fmt.Sprintf("%s | %s", formatSize(size), modTime.Format("2006-01-02 15:04"))
|
||||||
|
fileTable.SetCell(r+1, c,
|
||||||
|
tview.NewTableCell(preview).
|
||||||
SetTextColor(color).
|
SetTextColor(color).
|
||||||
SetAlign(tview.AlignCenter).
|
SetAlign(tview.AlignCenter).
|
||||||
SetSelectable(false))
|
SetSelectable(false))
|
||||||
default: // Action button column - selectable
|
} else {
|
||||||
fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0
|
fileTable.SetCell(r+1, c,
|
||||||
tview.NewTableCell(actions[c-1]).
|
tview.NewTableCell("error").
|
||||||
|
SetTextColor(color).
|
||||||
|
SetAlign(tview.AlignCenter).
|
||||||
|
SetSelectable(false))
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
actionText := "load"
|
||||||
|
if f.isLoaded {
|
||||||
|
actionText = "unload"
|
||||||
|
}
|
||||||
|
if !f.inRAGDir {
|
||||||
|
// Orphaned file - can only unload
|
||||||
|
actionText = "unload"
|
||||||
|
}
|
||||||
|
fileTable.SetCell(r+1, c,
|
||||||
|
tview.NewTableCell(actionText).
|
||||||
SetTextColor(color).
|
SetTextColor(color).
|
||||||
SetAlign(tview.AlignCenter))
|
SetAlign(tview.AlignCenter))
|
||||||
|
case 3:
|
||||||
|
if !f.inRAGDir {
|
||||||
|
// Orphaned file - cannot delete from ragdir (not there)
|
||||||
|
fileTable.SetCell(r+1, c,
|
||||||
|
tview.NewTableCell("-").
|
||||||
|
SetTextColor(tcell.ColorDarkGray).
|
||||||
|
SetAlign(tview.AlignCenter).
|
||||||
|
SetSelectable(false))
|
||||||
|
} else {
|
||||||
|
fileTable.SetCell(r+1, c,
|
||||||
|
tview.NewTableCell("delete").
|
||||||
|
SetTextColor(color).
|
||||||
|
SetAlign(tview.AlignCenter))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -318,7 +419,7 @@ func makeRAGTable(fileList []string) *tview.Flex {
|
|||||||
}()
|
}()
|
||||||
fileTable.Select(0, 0).
|
fileTable.Select(0, 0).
|
||||||
SetFixed(1, 1).
|
SetFixed(1, 1).
|
||||||
SetSelectable(true, false).
|
SetSelectable(true, true).
|
||||||
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
||||||
SetDoneFunc(func(key tcell.Key) {
|
SetDoneFunc(func(key tcell.Key) {
|
||||||
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') || key == tcell.KeyCtrlX {
|
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') || key == tcell.KeyCtrlX {
|
||||||
@@ -335,30 +436,58 @@ func makeRAGTable(fileList []string) *tview.Flex {
|
|||||||
}
|
}
|
||||||
// defer pages.RemovePage(RAGPage)
|
// defer pages.RemovePage(RAGPage)
|
||||||
tc := fileTable.GetCell(row, column)
|
tc := fileTable.GetCell(row, column)
|
||||||
|
tc.SetTextColor(tcell.ColorRed)
|
||||||
|
fileTable.SetSelectable(false, false)
|
||||||
// Check if the selected row is the exit row (row 0) - do this first to avoid index issues
|
// Check if the selected row is the exit row (row 0) - do this first to avoid index issues
|
||||||
if row == 0 {
|
if row == 0 {
|
||||||
pages.RemovePage(RAGPage)
|
pages.RemovePage(RAGPage)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// For file rows, get the filename (row index - 1 because of the exit row at index 0)
|
// For file rows, get the file info (row index - 1 because of the exit row at index 0)
|
||||||
fpath := fileList[row-1] // -1 to account for the exit row at index 0
|
f := ragFiles[row-1]
|
||||||
// notification := fmt.Sprintf("chat: %s; action: %s", fpath, tc.Text)
|
// Handle "-" case (orphaned file with no delete option)
|
||||||
|
if tc.Text == "-" {
|
||||||
|
pages.RemovePage(RAGPage)
|
||||||
|
return
|
||||||
|
}
|
||||||
switch tc.Text {
|
switch tc.Text {
|
||||||
case "load":
|
case "load":
|
||||||
fpath = path.Join(cfg.RAGDir, fpath)
|
fpath := path.Join(cfg.RAGDir, f.name)
|
||||||
longStatusView.SetText("clicked load")
|
longStatusView.SetText("clicked load")
|
||||||
go func() {
|
go func() {
|
||||||
if err := ragger.LoadRAG(fpath); err != nil {
|
if err := ragger.LoadRAG(fpath); err != nil {
|
||||||
logger.Error("failed to embed file", "chat", fpath, "error", err)
|
logger.Error("failed to embed file", "chat", fpath, "error", err)
|
||||||
_ = notifyUser("RAG", "failed to embed file; error: "+err.Error())
|
_ = notifyUser("RAG", "failed to embed file; error: "+err.Error())
|
||||||
errCh <- err
|
app.QueueUpdate(func() {
|
||||||
// pages.RemovePage(RAGPage)
|
pages.RemovePage(RAGPage)
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
_ = notifyUser("RAG", "file loaded successfully")
|
||||||
|
app.QueueUpdate(func() {
|
||||||
|
pages.RemovePage(RAGPage)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
case "unload":
|
||||||
|
longStatusView.SetText("clicked unload")
|
||||||
|
go func() {
|
||||||
|
if err := ragger.RemoveFile(f.name); err != nil {
|
||||||
|
logger.Error("failed to unload file from RAG", "filename", f.name, "error", err)
|
||||||
|
_ = notifyUser("RAG", "failed to unload file; error: "+err.Error())
|
||||||
|
app.QueueUpdate(func() {
|
||||||
|
pages.RemovePage(RAGPage)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = notifyUser("RAG", "file unloaded successfully")
|
||||||
|
app.QueueUpdate(func() {
|
||||||
|
pages.RemovePage(RAGPage)
|
||||||
|
})
|
||||||
}()
|
}()
|
||||||
return
|
return
|
||||||
case "delete":
|
case "delete":
|
||||||
fpath = path.Join(cfg.RAGDir, fpath)
|
fpath := path.Join(cfg.RAGDir, f.name)
|
||||||
if err := os.Remove(fpath); err != nil {
|
if err := os.Remove(fpath); err != nil {
|
||||||
logger.Error("failed to delete file", "filename", fpath, "error", err)
|
logger.Error("failed to delete file", "filename", fpath, "error", err)
|
||||||
return
|
return
|
||||||
@@ -383,114 +512,6 @@ func makeRAGTable(fileList []string) *tview.Flex {
|
|||||||
return ragflex
|
return ragflex
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeLoadedRAGTable(fileList []string) *tview.Flex {
|
|
||||||
actions := []string{"delete"}
|
|
||||||
rows, cols := len(fileList), len(actions)+1
|
|
||||||
// Add 1 extra row for the "exit" option at the top
|
|
||||||
fileTable := tview.NewTable().
|
|
||||||
SetBorders(true)
|
|
||||||
longStatusView := tview.NewTextView()
|
|
||||||
longStatusView.SetText("Loaded RAG files list")
|
|
||||||
longStatusView.SetBorder(true).SetTitle("status")
|
|
||||||
longStatusView.SetChangedFunc(func() {
|
|
||||||
app.Draw()
|
|
||||||
})
|
|
||||||
ragflex := tview.NewFlex().SetDirection(tview.FlexRow).
|
|
||||||
AddItem(longStatusView, 0, 10, false).
|
|
||||||
AddItem(fileTable, 0, 60, true)
|
|
||||||
// Add the exit option as the first row (row 0)
|
|
||||||
fileTable.SetCell(0, 0,
|
|
||||||
tview.NewTableCell("Exit Loaded Files manager").
|
|
||||||
SetTextColor(tcell.ColorWhite).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
fileTable.SetCell(0, 1,
|
|
||||||
tview.NewTableCell("(Close without action)").
|
|
||||||
SetTextColor(tcell.ColorGray).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
fileTable.SetCell(0, 2,
|
|
||||||
tview.NewTableCell("exit").
|
|
||||||
SetTextColor(tcell.ColorGray).
|
|
||||||
SetAlign(tview.AlignCenter))
|
|
||||||
// Add the file rows starting from row 1
|
|
||||||
for r := 0; r < rows; r++ {
|
|
||||||
for c := 0; c < cols; c++ {
|
|
||||||
color := tcell.ColorWhite
|
|
||||||
switch {
|
|
||||||
case c < 1:
|
|
||||||
fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0
|
|
||||||
tview.NewTableCell(fileList[r]).
|
|
||||||
SetTextColor(color).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
case c == 1: // Action description column - not selectable
|
|
||||||
fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0
|
|
||||||
tview.NewTableCell("(Action)").
|
|
||||||
SetTextColor(color).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
default: // Action button column - selectable
|
|
||||||
fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0
|
|
||||||
tview.NewTableCell(actions[c-1]).
|
|
||||||
SetTextColor(color).
|
|
||||||
SetAlign(tview.AlignCenter))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fileTable.Select(0, 0).
|
|
||||||
SetFixed(1, 1).
|
|
||||||
SetSelectable(true, false).
|
|
||||||
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
|
||||||
SetDoneFunc(func(key tcell.Key) {
|
|
||||||
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') || key == tcell.KeyCtrlX {
|
|
||||||
pages.RemovePage(RAGLoadedPage)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}).SetSelectedFunc(func(row int, column int) {
|
|
||||||
// If user selects a non-actionable column (0 or 1), move to first action column (2)
|
|
||||||
if column <= 1 {
|
|
||||||
if fileTable.GetColumnCount() > 2 {
|
|
||||||
fileTable.Select(row, 2) // Select first action column
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tc := fileTable.GetCell(row, column)
|
|
||||||
// Check if the selected row is the exit row (row 0) - do this first to avoid index issues
|
|
||||||
if row == 0 {
|
|
||||||
pages.RemovePage(RAGLoadedPage)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// For file rows, get the filename (row index - 1 because of the exit row at index 0)
|
|
||||||
fpath := fileList[row-1] // -1 to account for the exit row at index 0
|
|
||||||
switch tc.Text {
|
|
||||||
case "delete":
|
|
||||||
if err := ragger.RemoveFile(fpath); err != nil {
|
|
||||||
logger.Error("failed to delete file from RAG", "filename", fpath, "error", err)
|
|
||||||
longStatusView.SetText(fmt.Sprintf("Error deleting file: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := notifyUser("RAG file deleted", fpath+" was deleted from RAG system"); err != nil {
|
|
||||||
logger.Error("failed to send notification", "error", err)
|
|
||||||
}
|
|
||||||
longStatusView.SetText(fpath + " was deleted from RAG system")
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
pages.RemovePage(RAGLoadedPage)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
// Add input capture to the flex container to handle 'x' key for closing
|
|
||||||
ragflex.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
|
||||||
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
|
||||||
pages.RemovePage(RAGLoadedPage)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return event
|
|
||||||
})
|
|
||||||
return ragflex
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeAgentTable(agentList []string) *tview.Table {
|
func makeAgentTable(agentList []string) *tview.Table {
|
||||||
actions := []string{"filepath", "load"}
|
actions := []string{"filepath", "load"}
|
||||||
rows, cols := len(agentList), len(actions)+1
|
rows, cols := len(agentList), len(actions)+1
|
||||||
@@ -499,14 +520,14 @@ func makeAgentTable(agentList []string) *tview.Table {
|
|||||||
for r := 0; r < rows; r++ {
|
for r := 0; r < rows; r++ {
|
||||||
for c := 0; c < cols; c++ {
|
for c := 0; c < cols; c++ {
|
||||||
color := tcell.ColorWhite
|
color := tcell.ColorWhite
|
||||||
switch {
|
switch c {
|
||||||
case c < 1:
|
case 0:
|
||||||
chatActTable.SetCell(r, c,
|
chatActTable.SetCell(r, c,
|
||||||
tview.NewTableCell(agentList[r]).
|
tview.NewTableCell(agentList[r]).
|
||||||
SetTextColor(color).
|
SetTextColor(color).
|
||||||
SetAlign(tview.AlignCenter).
|
SetAlign(tview.AlignCenter).
|
||||||
SetSelectable(false))
|
SetSelectable(false))
|
||||||
case c == 1:
|
case 1:
|
||||||
if actions[c-1] == "filepath" {
|
if actions[c-1] == "filepath" {
|
||||||
cc, ok := sysMap[agentList[r]]
|
cc, ok := sysMap[agentList[r]]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -533,7 +554,7 @@ func makeAgentTable(agentList []string) *tview.Table {
|
|||||||
}
|
}
|
||||||
chatActTable.Select(0, 0).
|
chatActTable.Select(0, 0).
|
||||||
SetFixed(1, 1).
|
SetFixed(1, 1).
|
||||||
SetSelectable(true, false).
|
SetSelectable(true, true).
|
||||||
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
||||||
SetDoneFunc(func(key tcell.Key) {
|
SetDoneFunc(func(key tcell.Key) {
|
||||||
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') {
|
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') {
|
||||||
@@ -549,6 +570,8 @@ func makeAgentTable(agentList []string) *tview.Table {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
tc := chatActTable.GetCell(row, column)
|
tc := chatActTable.GetCell(row, column)
|
||||||
|
tc.SetTextColor(tcell.ColorRed)
|
||||||
|
chatActTable.SetSelectable(false, false)
|
||||||
selected := agentList[row]
|
selected := agentList[row]
|
||||||
// notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text)
|
// notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text)
|
||||||
switch tc.Text {
|
switch tc.Text {
|
||||||
@@ -630,7 +653,7 @@ func makeCodeBlockTable(codeBlocks []string) *tview.Table {
|
|||||||
}
|
}
|
||||||
table.Select(0, 0).
|
table.Select(0, 0).
|
||||||
SetFixed(1, 1).
|
SetFixed(1, 1).
|
||||||
SetSelectable(true, false).
|
SetSelectable(true, true).
|
||||||
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
||||||
SetDoneFunc(func(key tcell.Key) {
|
SetDoneFunc(func(key tcell.Key) {
|
||||||
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') {
|
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') {
|
||||||
@@ -646,6 +669,8 @@ func makeCodeBlockTable(codeBlocks []string) *tview.Table {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
tc := table.GetCell(row, column)
|
tc := table.GetCell(row, column)
|
||||||
|
tc.SetTextColor(tcell.ColorRed)
|
||||||
|
table.SetSelectable(false, false)
|
||||||
selected := codeBlocks[row]
|
selected := codeBlocks[row]
|
||||||
// notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text)
|
// notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text)
|
||||||
switch tc.Text {
|
switch tc.Text {
|
||||||
@@ -702,7 +727,7 @@ func makeImportChatTable(filenames []string) *tview.Table {
|
|||||||
}
|
}
|
||||||
chatActTable.Select(0, 0).
|
chatActTable.Select(0, 0).
|
||||||
SetFixed(1, 1).
|
SetFixed(1, 1).
|
||||||
SetSelectable(true, false).
|
SetSelectable(true, true).
|
||||||
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
||||||
SetDoneFunc(func(key tcell.Key) {
|
SetDoneFunc(func(key tcell.Key) {
|
||||||
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') {
|
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') {
|
||||||
@@ -718,6 +743,8 @@ func makeImportChatTable(filenames []string) *tview.Table {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
tc := chatActTable.GetCell(row, column)
|
tc := chatActTable.GetCell(row, column)
|
||||||
|
tc.SetTextColor(tcell.ColorRed)
|
||||||
|
chatActTable.SetSelectable(false, false)
|
||||||
selected := filenames[row]
|
selected := filenames[row]
|
||||||
// notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text)
|
// notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text)
|
||||||
switch tc.Text {
|
switch tc.Text {
|
||||||
@@ -792,6 +819,7 @@ func makeFilePicker() *tview.Flex {
|
|||||||
// --- NEW: search state ---
|
// --- NEW: search state ---
|
||||||
searching := false
|
searching := false
|
||||||
searchQuery := ""
|
searchQuery := ""
|
||||||
|
searchInputMode := false
|
||||||
// Helper function to check if a file has an allowed extension from config
|
// Helper function to check if a file has an allowed extension from config
|
||||||
hasAllowedExtension := func(filename string) bool {
|
hasAllowedExtension := func(filename string) bool {
|
||||||
if cfg.FilePickerExts == "" {
|
if cfg.FilePickerExts == "" {
|
||||||
@@ -984,6 +1012,7 @@ func makeFilePicker() *tview.Flex {
|
|||||||
case tcell.KeyEsc:
|
case tcell.KeyEsc:
|
||||||
// Exit search, clear filter
|
// Exit search, clear filter
|
||||||
searching = false
|
searching = false
|
||||||
|
searchInputMode = false
|
||||||
searchQuery = ""
|
searchQuery = ""
|
||||||
refreshList(currentDisplayDir, "")
|
refreshList(currentDisplayDir, "")
|
||||||
return nil
|
return nil
|
||||||
@@ -993,16 +1022,80 @@ func makeFilePicker() *tview.Flex {
|
|||||||
refreshList(currentDisplayDir, searchQuery)
|
refreshList(currentDisplayDir, searchQuery)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case tcell.KeyRune:
|
case tcell.KeyEnter:
|
||||||
r := event.Rune()
|
// Exit search input mode and let normal processing handle selection
|
||||||
if r != 0 {
|
searchInputMode = false
|
||||||
searchQuery += string(r)
|
// Get the currently highlighted item in the list
|
||||||
refreshList(currentDisplayDir, searchQuery)
|
itemIndex := listView.GetCurrentItem()
|
||||||
|
if itemIndex >= 0 && itemIndex < listView.GetItemCount() {
|
||||||
|
itemText, _ := listView.GetItemText(itemIndex)
|
||||||
|
// Check for the exit option first
|
||||||
|
if strings.HasPrefix(itemText, "Exit file picker") {
|
||||||
|
pages.RemovePage(filePickerPage)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Extract the actual filename/directory name by removing the type info
|
||||||
|
actualItemName := itemText
|
||||||
|
if bracketPos := strings.Index(itemText, " ["); bracketPos != -1 {
|
||||||
|
actualItemName = itemText[:bracketPos]
|
||||||
|
}
|
||||||
|
// Check if it's a directory (ends with /)
|
||||||
|
if strings.HasSuffix(actualItemName, "/") {
|
||||||
|
var targetDir string
|
||||||
|
if strings.HasPrefix(actualItemName, "../") {
|
||||||
|
// Parent directory
|
||||||
|
targetDir = path.Dir(currentDisplayDir)
|
||||||
|
if targetDir == currentDisplayDir && currentDisplayDir == "/" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Regular subdirectory
|
||||||
|
dirName := strings.TrimSuffix(actualItemName, "/")
|
||||||
|
targetDir = path.Join(currentDisplayDir, dirName)
|
||||||
|
}
|
||||||
|
// Navigate – clear search
|
||||||
|
if cfg.ImagePreview && imgPreview != nil {
|
||||||
|
imgPreview.SetImage(nil)
|
||||||
|
}
|
||||||
|
searching = false
|
||||||
|
searchInputMode = false
|
||||||
|
searchQuery = ""
|
||||||
|
refreshList(targetDir, "")
|
||||||
|
dirStack = append(dirStack, targetDir)
|
||||||
|
currentStackPos = len(dirStack) - 1
|
||||||
|
statusView.SetText("Current: " + targetDir)
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
// It's a file
|
||||||
|
filePath := path.Join(currentDisplayDir, actualItemName)
|
||||||
|
if info, err := os.Stat(filePath); err == nil && !info.IsDir() {
|
||||||
|
if isImageFile(actualItemName) {
|
||||||
|
SetImageAttachment(filePath)
|
||||||
|
statusView.SetText("Image attached: " + filePath + " (will be sent with next message)")
|
||||||
|
pages.RemovePage(filePickerPage)
|
||||||
|
} else {
|
||||||
|
textArea.SetText(filePath, true)
|
||||||
|
app.SetFocus(textArea)
|
||||||
|
pages.RemovePage(filePickerPage)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case tcell.KeyRune:
|
||||||
|
r := event.Rune()
|
||||||
|
if searchInputMode && r != 0 {
|
||||||
|
searchQuery += string(r)
|
||||||
|
refreshList(currentDisplayDir, searchQuery)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// If not in search input mode, pass through for navigation
|
||||||
|
return event
|
||||||
default:
|
default:
|
||||||
// Pass all other keys (arrows, Enter, etc.) to normal processing
|
// Exit search input mode but keep filter active for navigation
|
||||||
// This allows selecting items while still in search mode
|
searchInputMode = false
|
||||||
|
// Pass all other keys (arrows, etc.) to normal processing
|
||||||
return event
|
return event
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1030,6 +1123,7 @@ func makeFilePicker() *tview.Flex {
|
|||||||
if event.Rune() == '/' {
|
if event.Rune() == '/' {
|
||||||
// Enter search mode
|
// Enter search mode
|
||||||
searching = true
|
searching = true
|
||||||
|
searchInputMode = true
|
||||||
searchQuery = ""
|
searchQuery = ""
|
||||||
refreshList(currentDisplayDir, "")
|
refreshList(currentDisplayDir, "")
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
140
tools.go
140
tools.go
@@ -16,6 +16,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gf-lt/rag"
|
||||||
|
|
||||||
"github.com/GrailFinder/searchagent/searcher"
|
"github.com/GrailFinder/searchagent/searcher"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -58,9 +60,9 @@ Your current tools:
|
|||||||
"when_to_use": "when asked to search the web for information; returns clean summary without html,css and other web elements; limit is optional (default 3)"
|
"when_to_use": "when asked to search the web for information; returns clean summary without html,css and other web elements; limit is optional (default 3)"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name":"websearch_raw",
|
"name":"rag_search",
|
||||||
"args": ["query", "limit"],
|
"args": ["query", "limit"],
|
||||||
"when_to_use": "when asked to search the web for information; returns raw data as is without processing; limit is optional (default 3)"
|
"when_to_use": "when asked to search the local document database for information; performs query refinement, semantic search, reranking, and synthesis; returns clean summary with sources; limit is optional (default 3)"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name":"read_url",
|
"name":"read_url",
|
||||||
@@ -146,6 +148,7 @@ under the topic: Adam's number is stored:
|
|||||||
After that you are free to respond to the user.
|
After that you are free to respond to the user.
|
||||||
`
|
`
|
||||||
webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.`
|
webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.`
|
||||||
|
ragSearchSysPrompt = `Synthesize the document search results, extracting key information and presenting a concise answer. Provide sources and document IDs where relevant.`
|
||||||
readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.`
|
readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.`
|
||||||
summarySysPrompt = `Please provide a concise summary of the following conversation. Focus on key points, decisions, and actions. Provide only the summary, no additional commentary.`
|
summarySysPrompt = `Please provide a concise summary of the following conversation. Focus on key points, decisions, and actions. Provide only the summary, no additional commentary.`
|
||||||
basicCard = &models.CharCard{
|
basicCard = &models.CharCard{
|
||||||
@@ -170,6 +173,9 @@ func init() {
|
|||||||
panic("failed to init seachagent; error: " + err.Error())
|
panic("failed to init seachagent; error: " + err.Error())
|
||||||
}
|
}
|
||||||
WebSearcher = sa
|
WebSearcher = sa
|
||||||
|
if err := rag.Init(cfg, logger, store); err != nil {
|
||||||
|
logger.Warn("failed to init rag; rag_search tool will not be available", "error", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getWebAgentClient returns a singleton AgentClient for web agents.
|
// getWebAgentClient returns a singleton AgentClient for web agents.
|
||||||
@@ -196,6 +202,8 @@ func getWebAgentClient() *agent.AgentClient {
|
|||||||
func registerWebAgents() {
|
func registerWebAgents() {
|
||||||
webAgentsOnce.Do(func() {
|
webAgentsOnce.Do(func() {
|
||||||
client := getWebAgentClient()
|
client := getWebAgentClient()
|
||||||
|
// Register rag_search agent
|
||||||
|
agent.Register("rag_search", agent.NewWebAgentB(client, ragSearchSysPrompt))
|
||||||
// Register websearch agent
|
// Register websearch agent
|
||||||
agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt))
|
agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt))
|
||||||
// Register read_url agent
|
// Register read_url agent
|
||||||
@@ -239,6 +247,45 @@ func websearch(args map[string]string) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rag search (searches local document database)
|
||||||
|
func ragsearch(args map[string]string) []byte {
|
||||||
|
query, ok := args["query"]
|
||||||
|
if !ok || query == "" {
|
||||||
|
msg := "query not provided to rag_search tool"
|
||||||
|
logger.Error(msg)
|
||||||
|
return []byte(msg)
|
||||||
|
}
|
||||||
|
limitS, ok := args["limit"]
|
||||||
|
if !ok || limitS == "" {
|
||||||
|
limitS = "3"
|
||||||
|
}
|
||||||
|
limit, err := strconv.Atoi(limitS)
|
||||||
|
if err != nil || limit == 0 {
|
||||||
|
logger.Warn("ragsearch limit; passed bad value; setting to default (3)",
|
||||||
|
"limit_arg", limitS, "error", err)
|
||||||
|
limit = 3
|
||||||
|
}
|
||||||
|
ragInstance := rag.GetInstance()
|
||||||
|
if ragInstance == nil {
|
||||||
|
msg := "rag not initialized; rag_search tool is not available"
|
||||||
|
logger.Error(msg)
|
||||||
|
return []byte(msg)
|
||||||
|
}
|
||||||
|
results, err := ragInstance.Search(query, limit)
|
||||||
|
if err != nil {
|
||||||
|
msg := "rag search failed; error: " + err.Error()
|
||||||
|
logger.Error(msg)
|
||||||
|
return []byte(msg)
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(results)
|
||||||
|
if err != nil {
|
||||||
|
msg := "failed to marshal rag search result; error: " + err.Error()
|
||||||
|
logger.Error(msg)
|
||||||
|
return []byte(msg)
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
// web search raw (returns raw data without processing)
|
// web search raw (returns raw data without processing)
|
||||||
func websearchRaw(args map[string]string) []byte {
|
func websearchRaw(args map[string]string) []byte {
|
||||||
// make http request return bytes
|
// make http request return bytes
|
||||||
@@ -369,7 +416,6 @@ func recallTopics(args map[string]string) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// File Manipulation Tools
|
// File Manipulation Tools
|
||||||
|
|
||||||
func fileCreate(args map[string]string) []byte {
|
func fileCreate(args map[string]string) []byte {
|
||||||
path, ok := args["path"]
|
path, ok := args["path"]
|
||||||
if !ok || path == "" {
|
if !ok || path == "" {
|
||||||
@@ -377,20 +423,16 @@ func fileCreate(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
path = resolvePath(path)
|
path = resolvePath(path)
|
||||||
|
|
||||||
content, ok := args["content"]
|
content, ok := args["content"]
|
||||||
if !ok {
|
if !ok {
|
||||||
content = ""
|
content = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeStringToFile(path, content); err != nil {
|
if err := writeStringToFile(path, content); err != nil {
|
||||||
msg := "failed to create file; error: " + err.Error()
|
msg := "failed to create file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := "file created successfully at " + path
|
msg := "file created successfully at " + path
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
@@ -402,16 +444,13 @@ func fileRead(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
path = resolvePath(path)
|
path = resolvePath(path)
|
||||||
|
|
||||||
content, err := readStringFromFile(path)
|
content, err := readStringFromFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to read file; error: " + err.Error()
|
msg := "failed to read file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"content": content,
|
"content": content,
|
||||||
"path": path,
|
"path": path,
|
||||||
@@ -422,7 +461,6 @@ func fileRead(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,15 +513,12 @@ func fileDelete(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
path = resolvePath(path)
|
path = resolvePath(path)
|
||||||
|
|
||||||
if err := removeFile(path); err != nil {
|
if err := removeFile(path); err != nil {
|
||||||
msg := "failed to delete file; error: " + err.Error()
|
msg := "failed to delete file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := "file deleted successfully at " + path
|
msg := "file deleted successfully at " + path
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
@@ -496,7 +531,6 @@ func fileMove(args map[string]string) []byte {
|
|||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
src = resolvePath(src)
|
src = resolvePath(src)
|
||||||
|
|
||||||
dst, ok := args["dst"]
|
dst, ok := args["dst"]
|
||||||
if !ok || dst == "" {
|
if !ok || dst == "" {
|
||||||
msg := "destination path not provided to file_move tool"
|
msg := "destination path not provided to file_move tool"
|
||||||
@@ -504,13 +538,11 @@ func fileMove(args map[string]string) []byte {
|
|||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
dst = resolvePath(dst)
|
dst = resolvePath(dst)
|
||||||
|
|
||||||
if err := moveFile(src, dst); err != nil {
|
if err := moveFile(src, dst); err != nil {
|
||||||
msg := "failed to move file; error: " + err.Error()
|
msg := "failed to move file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := fmt.Sprintf("file moved successfully from %s to %s", src, dst)
|
msg := fmt.Sprintf("file moved successfully from %s to %s", src, dst)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
@@ -523,7 +555,6 @@ func fileCopy(args map[string]string) []byte {
|
|||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
src = resolvePath(src)
|
src = resolvePath(src)
|
||||||
|
|
||||||
dst, ok := args["dst"]
|
dst, ok := args["dst"]
|
||||||
if !ok || dst == "" {
|
if !ok || dst == "" {
|
||||||
msg := "destination path not provided to file_copy tool"
|
msg := "destination path not provided to file_copy tool"
|
||||||
@@ -531,13 +562,11 @@ func fileCopy(args map[string]string) []byte {
|
|||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
dst = resolvePath(dst)
|
dst = resolvePath(dst)
|
||||||
|
|
||||||
if err := copyFile(src, dst); err != nil {
|
if err := copyFile(src, dst); err != nil {
|
||||||
msg := "failed to copy file; error: " + err.Error()
|
msg := "failed to copy file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := fmt.Sprintf("file copied successfully from %s to %s", src, dst)
|
msg := fmt.Sprintf("file copied successfully from %s to %s", src, dst)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
@@ -547,16 +576,13 @@ func fileList(args map[string]string) []byte {
|
|||||||
if !ok || path == "" {
|
if !ok || path == "" {
|
||||||
path = "." // default to current directory
|
path = "." // default to current directory
|
||||||
}
|
}
|
||||||
|
|
||||||
path = resolvePath(path)
|
path = resolvePath(path)
|
||||||
|
|
||||||
files, err := listDirectory(path)
|
files, err := listDirectory(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to list directory; error: " + err.Error()
|
msg := "failed to list directory; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
result := map[string]interface{}{
|
result := map[string]interface{}{
|
||||||
"directory": path,
|
"directory": path,
|
||||||
"files": files,
|
"files": files,
|
||||||
@@ -567,12 +593,10 @@ func fileList(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions for file operations
|
// Helper functions for file operations
|
||||||
|
|
||||||
func resolvePath(p string) string {
|
func resolvePath(p string) string {
|
||||||
if filepath.IsAbs(p) {
|
if filepath.IsAbs(p) {
|
||||||
return p
|
return p
|
||||||
@@ -598,7 +622,6 @@ func appendStringToFile(filename string, data string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
_, err = file.WriteString(data)
|
_, err = file.WriteString(data)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -622,13 +645,11 @@ func copyFile(src, dst string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer srcFile.Close()
|
defer srcFile.Close()
|
||||||
|
|
||||||
dstFile, err := os.Create(dst)
|
dstFile, err := os.Create(dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer dstFile.Close()
|
defer dstFile.Close()
|
||||||
|
|
||||||
_, err = io.Copy(dstFile, srcFile)
|
_, err = io.Copy(dstFile, srcFile)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -647,7 +668,6 @@ func listDirectory(path string) ([]string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var files []string
|
var files []string
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if entry.IsDir() {
|
if entry.IsDir() {
|
||||||
@@ -656,12 +676,10 @@ func listDirectory(path string) ([]string, error) {
|
|||||||
files = append(files, entry.Name())
|
files = append(files, entry.Name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return files, nil
|
return files, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Command Execution Tool
|
// Command Execution Tool
|
||||||
|
|
||||||
func executeCommand(args map[string]string) []byte {
|
func executeCommand(args map[string]string) []byte {
|
||||||
command, ok := args["command"]
|
command, ok := args["command"]
|
||||||
if !ok || command == "" {
|
if !ok || command == "" {
|
||||||
@@ -669,7 +687,6 @@ func executeCommand(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get arguments - handle both single arg and multiple args
|
// Get arguments - handle both single arg and multiple args
|
||||||
var cmdArgs []string
|
var cmdArgs []string
|
||||||
if args["args"] != "" {
|
if args["args"] != "" {
|
||||||
@@ -688,43 +705,36 @@ func executeCommand(args map[string]string) []byte {
|
|||||||
argNum++
|
argNum++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isCommandAllowed(command, cmdArgs...) {
|
if !isCommandAllowed(command, cmdArgs...) {
|
||||||
msg := fmt.Sprintf("command '%s' is not allowed", command)
|
msg := fmt.Sprintf("command '%s' is not allowed", command)
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute with timeout for safety
|
// Execute with timeout for safety
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
cmd := exec.CommandContext(ctx, command, cmdArgs...)
|
cmd := exec.CommandContext(ctx, command, cmdArgs...)
|
||||||
|
|
||||||
output, err := cmd.CombinedOutput()
|
output, err := cmd.CombinedOutput()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := fmt.Sprintf("command '%s' failed; error: %v; output: %s", command, err, string(output))
|
msg := fmt.Sprintf("command '%s' failed; error: %v; output: %s", command, err, string(output))
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if output is empty and return success message
|
// Check if output is empty and return success message
|
||||||
if len(output) == 0 {
|
if len(output) == 0 {
|
||||||
successMsg := fmt.Sprintf("command '%s %s' executed successfully and exited with code 0", command, strings.Join(cmdArgs, " "))
|
successMsg := fmt.Sprintf("command '%s %s' executed successfully and exited with code 0", command, strings.Join(cmdArgs, " "))
|
||||||
return []byte(successMsg)
|
return []byte(successMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return output
|
return output
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions for command execution
|
// Helper functions for command execution
|
||||||
|
|
||||||
// Todo structure
|
// Todo structure
|
||||||
type TodoItem struct {
|
type TodoItem struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Task string `json:"task"`
|
Task string `json:"task"`
|
||||||
Status string `json:"status"` // "pending", "in_progress", "completed"
|
Status string `json:"status"` // "pending", "in_progress", "completed"
|
||||||
}
|
}
|
||||||
|
|
||||||
type TodoList struct {
|
type TodoList struct {
|
||||||
Items []TodoItem `json:"items"`
|
Items []TodoItem `json:"items"`
|
||||||
}
|
}
|
||||||
@@ -742,32 +752,26 @@ func todoCreate(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate simple ID
|
// Generate simple ID
|
||||||
id := fmt.Sprintf("todo_%d", len(globalTodoList.Items)+1)
|
id := fmt.Sprintf("todo_%d", len(globalTodoList.Items)+1)
|
||||||
|
|
||||||
newItem := TodoItem{
|
newItem := TodoItem{
|
||||||
ID: id,
|
ID: id,
|
||||||
Task: task,
|
Task: task,
|
||||||
Status: "pending",
|
Status: "pending",
|
||||||
}
|
}
|
||||||
|
|
||||||
globalTodoList.Items = append(globalTodoList.Items, newItem)
|
globalTodoList.Items = append(globalTodoList.Items, newItem)
|
||||||
|
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"message": "todo created successfully",
|
"message": "todo created successfully",
|
||||||
"id": id,
|
"id": id,
|
||||||
"task": task,
|
"task": task,
|
||||||
"status": "pending",
|
"status": "pending",
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, err := json.Marshal(result)
|
jsonResult, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to marshal result; error: " + err.Error()
|
msg := "failed to marshal result; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -801,7 +805,6 @@ func todoRead(args map[string]string) []byte {
|
|||||||
}
|
}
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return all todos if no ID specified
|
// Return all todos if no ID specified
|
||||||
result := map[string]interface{}{
|
result := map[string]interface{}{
|
||||||
"todos": globalTodoList.Items,
|
"todos": globalTodoList.Items,
|
||||||
@@ -812,7 +815,6 @@ func todoRead(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -823,16 +825,13 @@ func todoUpdate(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
task, taskOk := args["task"]
|
task, taskOk := args["task"]
|
||||||
status, statusOk := args["status"]
|
status, statusOk := args["status"]
|
||||||
|
|
||||||
if !taskOk && !statusOk {
|
if !taskOk && !statusOk {
|
||||||
msg := "neither task nor status provided to todo_update tool"
|
msg := "neither task nor status provided to todo_update tool"
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find and update the todo
|
// Find and update the todo
|
||||||
for i, item := range globalTodoList.Items {
|
for i, item := range globalTodoList.Items {
|
||||||
if item.ID == id {
|
if item.ID == id {
|
||||||
@@ -856,23 +855,19 @@ func todoUpdate(args map[string]string) []byte {
|
|||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"message": "todo updated successfully",
|
"message": "todo updated successfully",
|
||||||
"id": id,
|
"id": id,
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, err := json.Marshal(result)
|
jsonResult, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to marshal result; error: " + err.Error()
|
msg := "failed to marshal result; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID not found
|
// ID not found
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"error": fmt.Sprintf("todo with id %s not found", id),
|
"error": fmt.Sprintf("todo with id %s not found", id),
|
||||||
@@ -893,29 +888,24 @@ func todoDelete(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find and remove the todo
|
// Find and remove the todo
|
||||||
for i, item := range globalTodoList.Items {
|
for i, item := range globalTodoList.Items {
|
||||||
if item.ID == id {
|
if item.ID == id {
|
||||||
// Remove item from slice
|
// Remove item from slice
|
||||||
globalTodoList.Items = append(globalTodoList.Items[:i], globalTodoList.Items[i+1:]...)
|
globalTodoList.Items = append(globalTodoList.Items[:i], globalTodoList.Items[i+1:]...)
|
||||||
|
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"message": "todo deleted successfully",
|
"message": "todo deleted successfully",
|
||||||
"id": id,
|
"id": id,
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, err := json.Marshal(result)
|
jsonResult, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to marshal result; error: " + err.Error()
|
msg := "failed to marshal result; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID not found
|
// ID not found
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"error": fmt.Sprintf("todo with id %s not found", id),
|
"error": fmt.Sprintf("todo with id %s not found", id),
|
||||||
@@ -997,6 +987,7 @@ var fnMap = map[string]fnSig{
|
|||||||
"recall": recall,
|
"recall": recall,
|
||||||
"recall_topics": recallTopics,
|
"recall_topics": recallTopics,
|
||||||
"memorise": memorise,
|
"memorise": memorise,
|
||||||
|
"rag_search": ragsearch,
|
||||||
"websearch": websearch,
|
"websearch": websearch,
|
||||||
"websearch_raw": websearchRaw,
|
"websearch_raw": websearchRaw,
|
||||||
"read_url": readURL,
|
"read_url": readURL,
|
||||||
@@ -1033,6 +1024,28 @@ func callToolWithAgent(name string, args map[string]string) []byte {
|
|||||||
|
|
||||||
// openai style def
|
// openai style def
|
||||||
var baseTools = []models.Tool{
|
var baseTools = []models.Tool{
|
||||||
|
// rag_search
|
||||||
|
models.Tool{
|
||||||
|
Type: "function",
|
||||||
|
Function: models.ToolFunc{
|
||||||
|
Name: "rag_search",
|
||||||
|
Description: "Search local document database given query, limit of sources (default 3). Performs query refinement, semantic search, reranking, and synthesis.",
|
||||||
|
Parameters: models.ToolFuncParams{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"query", "limit"},
|
||||||
|
Properties: map[string]models.ToolArgProps{
|
||||||
|
"query": models.ToolArgProps{
|
||||||
|
Type: "string",
|
||||||
|
Description: "search query",
|
||||||
|
},
|
||||||
|
"limit": models.ToolArgProps{
|
||||||
|
Type: "string",
|
||||||
|
Description: "limit of the document results",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
// websearch
|
// websearch
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1166,7 +1179,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_create
|
// file_create
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1189,7 +1201,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_read
|
// file_read
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1208,7 +1219,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_write
|
// file_write
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1231,7 +1241,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_write_append
|
// file_write_append
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1254,7 +1263,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_delete
|
// file_delete
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1273,7 +1281,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_move
|
// file_move
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1296,7 +1303,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_copy
|
// file_copy
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1319,7 +1325,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_list
|
// file_list
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1338,7 +1343,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// execute_command
|
// execute_command
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
|
|||||||
66
tui.go
66
tui.go
@@ -15,6 +15,11 @@ import (
|
|||||||
"github.com/rivo/tview"
|
"github.com/rivo/tview"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func isFullScreenPageActive() bool {
|
||||||
|
name, _ := pages.GetFrontPage()
|
||||||
|
return name != "main"
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
app *tview.Application
|
app *tview.Application
|
||||||
pages *tview.Pages
|
pages *tview.Pages
|
||||||
@@ -259,7 +264,7 @@ func init() {
|
|||||||
pages.RemovePage(editMsgPage)
|
pages.RemovePage(editMsgPage)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
chatBody.Messages[selectedIndex].Content = editedMsg
|
chatBody.Messages[selectedIndex].SetText(editedMsg)
|
||||||
// change textarea
|
// change textarea
|
||||||
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys))
|
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys))
|
||||||
pages.RemovePage(editMsgPage)
|
pages.RemovePage(editMsgPage)
|
||||||
@@ -347,13 +352,14 @@ func init() {
|
|||||||
case editMode:
|
case editMode:
|
||||||
hideIndexBar() // Hide overlay first
|
hideIndexBar() // Hide overlay first
|
||||||
pages.AddPage(editMsgPage, editArea, true, true)
|
pages.AddPage(editMsgPage, editArea, true, true)
|
||||||
editArea.SetText(m.Content, true)
|
editArea.SetText(m.GetText(), true)
|
||||||
default:
|
default:
|
||||||
if err := copyToClipboard(m.Content); err != nil {
|
msgText := m.GetText()
|
||||||
|
if err := copyToClipboard(msgText); err != nil {
|
||||||
logger.Error("failed to copy to clipboard", "error", err)
|
logger.Error("failed to copy to clipboard", "error", err)
|
||||||
}
|
}
|
||||||
previewLen := min(30, len(m.Content))
|
previewLen := min(30, len(msgText))
|
||||||
notification := fmt.Sprintf("msg '%s' was copied to the clipboard", m.Content[:previewLen])
|
notification := fmt.Sprintf("msg '%s' was copied to the clipboard", msgText[:previewLen])
|
||||||
if err := notifyUser("copied", notification); err != nil {
|
if err := notifyUser("copied", notification); err != nil {
|
||||||
logger.Error("failed to send notification", "error", err)
|
logger.Error("failed to send notification", "error", err)
|
||||||
}
|
}
|
||||||
@@ -525,6 +531,9 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyRune && event.Rune() == 'i' && event.Modifiers()&tcell.ModAlt != 0 {
|
if event.Key() == tcell.KeyRune && event.Rune() == 'i' && event.Modifiers()&tcell.ModAlt != 0 {
|
||||||
|
if isFullScreenPageActive() {
|
||||||
|
return event
|
||||||
|
}
|
||||||
showColorschemeSelectionPopup()
|
showColorschemeSelectionPopup()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -640,11 +649,12 @@ func init() {
|
|||||||
// copy msg to clipboard
|
// copy msg to clipboard
|
||||||
editMode = false
|
editMode = false
|
||||||
m := chatBody.Messages[len(chatBody.Messages)-1]
|
m := chatBody.Messages[len(chatBody.Messages)-1]
|
||||||
if err := copyToClipboard(m.Content); err != nil {
|
msgText := m.GetText()
|
||||||
|
if err := copyToClipboard(msgText); err != nil {
|
||||||
logger.Error("failed to copy to clipboard", "error", err)
|
logger.Error("failed to copy to clipboard", "error", err)
|
||||||
}
|
}
|
||||||
previewLen := min(30, len(m.Content))
|
previewLen := min(30, len(msgText))
|
||||||
notification := fmt.Sprintf("msg '%s' was copied to the clipboard", m.Content[:previewLen])
|
notification := fmt.Sprintf("msg '%s' was copied to the clipboard", msgText[:previewLen])
|
||||||
if err := notifyUser("copied", notification); err != nil {
|
if err := notifyUser("copied", notification); err != nil {
|
||||||
logger.Error("failed to send notification", "error", err)
|
logger.Error("failed to send notification", "error", err)
|
||||||
}
|
}
|
||||||
@@ -731,6 +741,9 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyCtrlL {
|
if event.Key() == tcell.KeyCtrlL {
|
||||||
|
if isFullScreenPageActive() {
|
||||||
|
return event
|
||||||
|
}
|
||||||
// Show model selection popup instead of rotating models
|
// Show model selection popup instead of rotating models
|
||||||
showModelSelectionPopup()
|
showModelSelectionPopup()
|
||||||
return nil
|
return nil
|
||||||
@@ -744,6 +757,9 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyCtrlV {
|
if event.Key() == tcell.KeyCtrlV {
|
||||||
|
if isFullScreenPageActive() {
|
||||||
|
return event
|
||||||
|
}
|
||||||
// Show API link selection popup instead of rotating APIs
|
// Show API link selection popup instead of rotating APIs
|
||||||
showAPILinkSelectionPopup()
|
showAPILinkSelectionPopup()
|
||||||
return nil
|
return nil
|
||||||
@@ -833,7 +849,7 @@ func init() {
|
|||||||
// Stop any currently playing TTS first
|
// Stop any currently playing TTS first
|
||||||
TTSDoneChan <- true
|
TTSDoneChan <- true
|
||||||
lastMsg := chatBody.Messages[len(chatBody.Messages)-1]
|
lastMsg := chatBody.Messages[len(chatBody.Messages)-1]
|
||||||
cleanedText := models.CleanText(lastMsg.Content)
|
cleanedText := models.CleanText(lastMsg.GetText())
|
||||||
if cleanedText != "" {
|
if cleanedText != "" {
|
||||||
// nolint: errcheck
|
// nolint: errcheck
|
||||||
go orator.Speak(cleanedText)
|
go orator.Speak(cleanedText)
|
||||||
@@ -850,11 +866,17 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyCtrlQ {
|
if event.Key() == tcell.KeyCtrlQ {
|
||||||
|
if isFullScreenPageActive() {
|
||||||
|
return event
|
||||||
|
}
|
||||||
// Show user role selection popup instead of cycling through roles
|
// Show user role selection popup instead of cycling through roles
|
||||||
showUserRoleSelectionPopup()
|
showUserRoleSelectionPopup()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyCtrlX {
|
if event.Key() == tcell.KeyCtrlX {
|
||||||
|
if isFullScreenPageActive() {
|
||||||
|
return event
|
||||||
|
}
|
||||||
// Show bot role selection popup instead of cycling through roles
|
// Show bot role selection popup instead of cycling through roles
|
||||||
showBotRoleSelectionPopup()
|
showBotRoleSelectionPopup()
|
||||||
return nil
|
return nil
|
||||||
@@ -893,6 +915,7 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Get files from ragdir
|
||||||
fileList := []string{}
|
fileList := []string{}
|
||||||
for _, f := range files {
|
for _, f := range files {
|
||||||
if f.IsDir() {
|
if f.IsDir() {
|
||||||
@@ -900,22 +923,14 @@ func init() {
|
|||||||
}
|
}
|
||||||
fileList = append(fileList, f.Name())
|
fileList = append(fileList, f.Name())
|
||||||
}
|
}
|
||||||
chatRAGTable := makeRAGTable(fileList)
|
// Get loaded files from vector DB
|
||||||
pages.AddPage(RAGPage, chatRAGTable, true, true)
|
loadedFiles, err := ragger.ListLoaded()
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if event.Key() == tcell.KeyCtrlY { // Use Ctrl+Y to list loaded RAG files
|
|
||||||
// List files already loaded into the RAG system
|
|
||||||
fileList, err := ragger.ListLoaded()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("failed to list loaded RAG files", "error", err)
|
logger.Error("failed to list loaded RAG files", "error", err)
|
||||||
if notifyerr := notifyUser("failed to list RAG files", err.Error()); notifyerr != nil {
|
loadedFiles = []string{} // Continue with empty list on error
|
||||||
logger.Error("failed to send notification", "error", notifyerr)
|
|
||||||
}
|
}
|
||||||
return nil
|
chatRAGTable := makeRAGTable(fileList, loadedFiles)
|
||||||
}
|
pages.AddPage(RAGPage, chatRAGTable, true, true)
|
||||||
chatLoadedRAGTable := makeLoadedRAGTable(fileList)
|
|
||||||
pages.AddPage(RAGLoadedPage, chatLoadedRAGTable, true, true)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if event.Key() == tcell.KeyRune && event.Modifiers() == tcell.ModAlt && event.Rune() == '1' {
|
if event.Key() == tcell.KeyRune && event.Modifiers() == tcell.ModAlt && event.Rune() == '1' {
|
||||||
@@ -975,13 +990,6 @@ func init() {
|
|||||||
}
|
}
|
||||||
// go chatRound(msgText, persona, textView, false, false)
|
// go chatRound(msgText, persona, textView, false, false)
|
||||||
chatRoundChan <- &models.ChatRoundReq{Role: persona, UserMsg: msgText}
|
chatRoundChan <- &models.ChatRoundReq{Role: persona, UserMsg: msgText}
|
||||||
// Also clear any image attachment after sending the message
|
|
||||||
go func() {
|
|
||||||
// Wait a short moment for the message to be processed, then clear the image attachment
|
|
||||||
// This allows the image to be sent with the current message if it was attached
|
|
||||||
// But clears it for the next message
|
|
||||||
ClearImageAttachment()
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user