Compare commits
7 Commits
feat/ragto
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d51c5d0f3 | ||
|
|
b97cd67d72 | ||
|
|
888c9fec65 | ||
|
|
4f07994bdc | ||
|
|
776fd7a2c4 | ||
|
|
9c6b0dc1fa | ||
|
|
9f51bd3853 |
7
Makefile
7
Makefile
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: setconfig run lint setup-whisper build-whisper download-whisper-model docker-up docker-down docker-logs noextra-run installdelve checkdelve
|
.PHONY: setconfig run lint install-linters setup-whisper build-whisper download-whisper-model docker-up docker-down docker-logs noextra-run installdelve checkdelve
|
||||||
|
|
||||||
run: setconfig
|
run: setconfig
|
||||||
go build -tags extra -o gf-lt && ./gf-lt
|
go build -tags extra -o gf-lt && ./gf-lt
|
||||||
@@ -21,8 +21,11 @@ installdelve:
|
|||||||
checkdelve:
|
checkdelve:
|
||||||
which dlv &>/dev/null || installdelve
|
which dlv &>/dev/null || installdelve
|
||||||
|
|
||||||
|
install-linters: ## Install additional linters (noblanks)
|
||||||
|
go install github.com/GrailFinder/noblanks-linter/cmd/noblanks@latest
|
||||||
|
|
||||||
lint: ## Run linters. Use make install-linters first.
|
lint: ## Run linters. Use make install-linters first.
|
||||||
golangci-lint run -c .golangci.yml ./...
|
golangci-lint run -c .golangci.yml ./...; noblanks ./...
|
||||||
|
|
||||||
# Whisper STT Setup (in batteries directory)
|
# Whisper STT Setup (in batteries directory)
|
||||||
setup-whisper: build-whisper download-whisper-model
|
setup-whisper: build-whisper download-whisper-model
|
||||||
|
|||||||
@@ -140,7 +140,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
|||||||
ag.log.Error("failed to read request body", "error", err)
|
ag.log.Error("failed to read request body", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
|
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("failed to create request", "error", err)
|
ag.log.Error("failed to create request", "error", err)
|
||||||
@@ -150,22 +149,18 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
|||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Add("Authorization", "Bearer "+ag.getToken())
|
req.Header.Add("Authorization", "Bearer "+ag.getToken())
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
|
||||||
ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
|
ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
resp, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
|
ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
responseBytes, err := io.ReadAll(resp.Body)
|
responseBytes, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("failed to read response", "error", err)
|
ag.log.Error("failed to read response", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
|
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
|
||||||
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
|
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
|
||||||
@@ -178,7 +173,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
|||||||
// Return raw response as fallback
|
// Return raw response as fallback
|
||||||
return responseBytes, nil
|
return responseBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return []byte(text), nil
|
return []byte(text), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
76
bot.go
76
bot.go
@@ -23,8 +23,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/neurosnap/sentences/english"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -119,7 +117,7 @@ func processMessageTag(msg *models.RoleMsg) *models.RoleMsg {
|
|||||||
}
|
}
|
||||||
// If KnownTo already set, assume tag already processed (content cleaned).
|
// If KnownTo already set, assume tag already processed (content cleaned).
|
||||||
// However, we still check for new tags (maybe added later).
|
// However, we still check for new tags (maybe added later).
|
||||||
knownTo := parseKnownToTag(msg.Content)
|
knownTo := parseKnownToTag(msg.GetText())
|
||||||
// If tag found, replace KnownTo with new list (merge with existing?)
|
// If tag found, replace KnownTo with new list (merge with existing?)
|
||||||
// For simplicity, if knownTo is not nil, replace.
|
// For simplicity, if knownTo is not nil, replace.
|
||||||
if knownTo == nil {
|
if knownTo == nil {
|
||||||
@@ -753,62 +751,6 @@ func sendMsgToLLM(body io.Reader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func chatRagUse(qText string) (string, error) {
|
|
||||||
logger.Debug("Starting RAG query", "original_query", qText)
|
|
||||||
tokenizer, err := english.NewSentenceTokenizer(nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to create sentence tokenizer", "error", err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// this where llm should find the questions in text and ask them
|
|
||||||
questionsS := tokenizer.Tokenize(qText)
|
|
||||||
questions := make([]string, len(questionsS))
|
|
||||||
for i, q := range questionsS {
|
|
||||||
questions[i] = q.Text
|
|
||||||
logger.Debug("RAG question extracted", "index", i, "question", q.Text)
|
|
||||||
}
|
|
||||||
if len(questions) == 0 {
|
|
||||||
logger.Warn("No questions extracted from query text", "query", qText)
|
|
||||||
return "No related results from RAG vector storage.", nil
|
|
||||||
}
|
|
||||||
respVecs := []models.VectorRow{}
|
|
||||||
for i, q := range questions {
|
|
||||||
logger.Debug("Processing RAG question", "index", i, "question", q)
|
|
||||||
emb, err := ragger.LineToVector(q)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to get embeddings for RAG", "error", err, "index", i, "question", q)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.Debug("Got embeddings for question", "index", i, "question_len", len(q), "embedding_len", len(emb))
|
|
||||||
// Create EmbeddingResp struct for the search
|
|
||||||
embeddingResp := &models.EmbeddingResp{
|
|
||||||
Embedding: emb,
|
|
||||||
Index: 0, // Not used in search but required for the struct
|
|
||||||
}
|
|
||||||
vecs, err := ragger.SearchEmb(embeddingResp)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to query embeddings in RAG", "error", err, "index", i, "question", q)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.Debug("RAG search returned vectors", "index", i, "question", q, "vector_count", len(vecs))
|
|
||||||
respVecs = append(respVecs, vecs...)
|
|
||||||
}
|
|
||||||
// get raw text
|
|
||||||
resps := []string{}
|
|
||||||
logger.Debug("RAG query final results", "total_vecs_found", len(respVecs))
|
|
||||||
for _, rv := range respVecs {
|
|
||||||
resps = append(resps, rv.RawText)
|
|
||||||
logger.Debug("RAG result", "slug", rv.Slug, "filename", rv.FileName, "raw_text_len", len(rv.RawText))
|
|
||||||
}
|
|
||||||
if len(resps) == 0 {
|
|
||||||
logger.Info("No RAG results found for query", "original_query", qText, "question_count", len(questions))
|
|
||||||
return "No related results from RAG vector storage.", nil
|
|
||||||
}
|
|
||||||
result := strings.Join(resps, "\n")
|
|
||||||
logger.Debug("RAG query completed", "result_len", len(result), "response_count", len(resps))
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func roleToIcon(role string) string {
|
func roleToIcon(role string) string {
|
||||||
return "<" + role + ">: "
|
return "<" + role + ">: "
|
||||||
}
|
}
|
||||||
@@ -838,11 +780,12 @@ func showSpinner() {
|
|||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
spin := i % len(spinners)
|
spin := i % len(spinners)
|
||||||
app.QueueUpdateDraw(func() {
|
app.QueueUpdateDraw(func() {
|
||||||
if toolRunningMode {
|
switch {
|
||||||
|
case toolRunningMode:
|
||||||
textArea.SetTitle(spinners[spin] + " tool")
|
textArea.SetTitle(spinners[spin] + " tool")
|
||||||
} else if botRespMode {
|
case botRespMode:
|
||||||
textArea.SetTitle(spinners[spin] + " " + botPersona)
|
textArea.SetTitle(spinners[spin] + " " + botPersona)
|
||||||
} else {
|
default:
|
||||||
textArea.SetTitle(spinners[spin] + " input")
|
textArea.SetTitle(spinners[spin] + " input")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -1303,12 +1246,9 @@ func removeThinking(chatBody *models.ChatBody) {
|
|||||||
if msg.Role == cfg.ToolRole {
|
if msg.Role == cfg.ToolRole {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// find thinking and remove it
|
// find thinking and remove it - use SetText to preserve ContentParts
|
||||||
rm := models.RoleMsg{
|
msg.SetText(thinkRE.ReplaceAllString(msg.GetText(), ""))
|
||||||
Role: msg.Role,
|
msgs = append(msgs, msg)
|
||||||
Content: thinkRE.ReplaceAllString(msg.Content, ""),
|
|
||||||
}
|
|
||||||
msgs = append(msgs, rm)
|
|
||||||
}
|
}
|
||||||
chatBody.Messages = msgs
|
chatBody.Messages = msgs
|
||||||
}
|
}
|
||||||
|
|||||||
34
bot_test.go
34
bot_test.go
@@ -1,12 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"gf-lt/config"
|
"gf-lt/config"
|
||||||
"gf-lt/models"
|
"gf-lt/models"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
||||||
// Mock config for testing
|
// Mock config for testing
|
||||||
testCfg := &config.Config{
|
testCfg := &config.Config{
|
||||||
@@ -14,7 +12,6 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
|||||||
WriteNextMsgAsCompletionAgent: "",
|
WriteNextMsgAsCompletionAgent: "",
|
||||||
}
|
}
|
||||||
cfg = testCfg
|
cfg = testCfg
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input []models.RoleMsg
|
input []models.RoleMsg
|
||||||
@@ -114,38 +111,31 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := consolidateAssistantMessages(tt.input)
|
result := consolidateAssistantMessages(tt.input)
|
||||||
|
|
||||||
if len(result) != len(tt.expected) {
|
if len(result) != len(tt.expected) {
|
||||||
t.Errorf("Expected %d messages, got %d", len(tt.expected), len(result))
|
t.Errorf("Expected %d messages, got %d", len(tt.expected), len(result))
|
||||||
t.Logf("Result: %+v", result)
|
t.Logf("Result: %+v", result)
|
||||||
t.Logf("Expected: %+v", tt.expected)
|
t.Logf("Expected: %+v", tt.expected)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, expectedMsg := range tt.expected {
|
for i, expectedMsg := range tt.expected {
|
||||||
if i >= len(result) {
|
if i >= len(result) {
|
||||||
t.Errorf("Result has fewer messages than expected at index %d", i)
|
t.Errorf("Result has fewer messages than expected at index %d", i)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
actualMsg := result[i]
|
actualMsg := result[i]
|
||||||
if actualMsg.Role != expectedMsg.Role {
|
if actualMsg.Role != expectedMsg.Role {
|
||||||
t.Errorf("Message %d: expected role '%s', got '%s'", i, expectedMsg.Role, actualMsg.Role)
|
t.Errorf("Message %d: expected role '%s', got '%s'", i, expectedMsg.Role, actualMsg.Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if actualMsg.Content != expectedMsg.Content {
|
if actualMsg.Content != expectedMsg.Content {
|
||||||
t.Errorf("Message %d: expected content '%s', got '%s'", i, expectedMsg.Content, actualMsg.Content)
|
t.Errorf("Message %d: expected content '%s', got '%s'", i, expectedMsg.Content, actualMsg.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
if actualMsg.ToolCallID != expectedMsg.ToolCallID {
|
if actualMsg.ToolCallID != expectedMsg.ToolCallID {
|
||||||
t.Errorf("Message %d: expected ToolCallID '%s', got '%s'", i, expectedMsg.ToolCallID, actualMsg.ToolCallID)
|
t.Errorf("Message %d: expected ToolCallID '%s', got '%s'", i, expectedMsg.ToolCallID, actualMsg.ToolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Additional check: ensure no messages were lost
|
// Additional check: ensure no messages were lost
|
||||||
if !reflect.DeepEqual(result, tt.expected) {
|
if !reflect.DeepEqual(result, tt.expected) {
|
||||||
t.Errorf("Result does not match expected:\nResult: %+v\nExpected: %+v", result, tt.expected)
|
t.Errorf("Result does not match expected:\nResult: %+v\nExpected: %+v", result, tt.expected)
|
||||||
@@ -153,7 +143,6 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalFuncCall(t *testing.T) {
|
func TestUnmarshalFuncCall(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -213,7 +202,6 @@ func TestUnmarshalFuncCall(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := unmarshalFuncCall(tt.jsonStr)
|
got, err := unmarshalFuncCall(tt.jsonStr)
|
||||||
@@ -238,7 +226,6 @@ func TestUnmarshalFuncCall(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertJSONToMapStringString(t *testing.T) {
|
func TestConvertJSONToMapStringString(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -265,7 +252,6 @@ func TestConvertJSONToMapStringString(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := convertJSONToMapStringString(tt.jsonStr)
|
got, err := convertJSONToMapStringString(tt.jsonStr)
|
||||||
@@ -287,7 +273,6 @@ func TestConvertJSONToMapStringString(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseKnownToTag(t *testing.T) {
|
func TestParseKnownToTag(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -378,7 +363,6 @@ func TestParseKnownToTag(t *testing.T) {
|
|||||||
wantKnownTo: []string{"Alice", "Bob", "Carl"},
|
wantKnownTo: []string{"Alice", "Bob", "Carl"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// Set up config
|
// Set up config
|
||||||
@@ -402,7 +386,6 @@ func TestParseKnownToTag(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessMessageTag(t *testing.T) {
|
func TestProcessMessageTag(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -498,7 +481,6 @@ func TestProcessMessageTag(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
testCfg := &config.Config{
|
testCfg := &config.Config{
|
||||||
@@ -529,7 +511,6 @@ func TestProcessMessageTag(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFilterMessagesForCharacter(t *testing.T) {
|
func TestFilterMessagesForCharacter(t *testing.T) {
|
||||||
messages := []models.RoleMsg{
|
messages := []models.RoleMsg{
|
||||||
{Role: "system", Content: "System message", KnownTo: nil}, // visible to all
|
{Role: "system", Content: "System message", KnownTo: nil}, // visible to all
|
||||||
@@ -539,7 +520,6 @@ func TestFilterMessagesForCharacter(t *testing.T) {
|
|||||||
{Role: "Alice", Content: "Private to Carl", KnownTo: []string{"Alice", "Carl"}},
|
{Role: "Alice", Content: "Private to Carl", KnownTo: []string{"Alice", "Carl"}},
|
||||||
{Role: "Carl", Content: "Hi all", KnownTo: nil}, // visible to all
|
{Role: "Carl", Content: "Hi all", KnownTo: nil}, // visible to all
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
enabled bool
|
enabled bool
|
||||||
@@ -583,7 +563,6 @@ func TestFilterMessagesForCharacter(t *testing.T) {
|
|||||||
wantIndices: []int{0, 1, 5},
|
wantIndices: []int{0, 1, 5},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
testCfg := &config.Config{
|
testCfg := &config.Config{
|
||||||
@@ -591,15 +570,12 @@ func TestFilterMessagesForCharacter(t *testing.T) {
|
|||||||
CharSpecificContextTag: "@",
|
CharSpecificContextTag: "@",
|
||||||
}
|
}
|
||||||
cfg = testCfg
|
cfg = testCfg
|
||||||
|
|
||||||
got := filterMessagesForCharacter(messages, tt.character)
|
got := filterMessagesForCharacter(messages, tt.character)
|
||||||
|
|
||||||
if len(got) != len(tt.wantIndices) {
|
if len(got) != len(tt.wantIndices) {
|
||||||
t.Errorf("filterMessagesForCharacter() returned %d messages, want %d", len(got), len(tt.wantIndices))
|
t.Errorf("filterMessagesForCharacter() returned %d messages, want %d", len(got), len(tt.wantIndices))
|
||||||
t.Logf("got: %v", got)
|
t.Logf("got: %v", got)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, idx := range tt.wantIndices {
|
for i, idx := range tt.wantIndices {
|
||||||
if got[i].Content != messages[idx].Content {
|
if got[i].Content != messages[idx].Content {
|
||||||
t.Errorf("filterMessagesForCharacter() message %d content = %q, want %q", i, got[i].Content, messages[idx].Content)
|
t.Errorf("filterMessagesForCharacter() message %d content = %q, want %q", i, got[i].Content, messages[idx].Content)
|
||||||
@@ -608,7 +584,6 @@ func TestFilterMessagesForCharacter(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoleMsgCopyPreservesKnownTo(t *testing.T) {
|
func TestRoleMsgCopyPreservesKnownTo(t *testing.T) {
|
||||||
// Test that the Copy() method preserves the KnownTo field
|
// Test that the Copy() method preserves the KnownTo field
|
||||||
originalMsg := models.RoleMsg{
|
originalMsg := models.RoleMsg{
|
||||||
@@ -616,9 +591,7 @@ func TestRoleMsgCopyPreservesKnownTo(t *testing.T) {
|
|||||||
Content: "Test message",
|
Content: "Test message",
|
||||||
KnownTo: []string{"Bob", "Charlie"},
|
KnownTo: []string{"Bob", "Charlie"},
|
||||||
}
|
}
|
||||||
|
|
||||||
copiedMsg := originalMsg.Copy()
|
copiedMsg := originalMsg.Copy()
|
||||||
|
|
||||||
if copiedMsg.Role != originalMsg.Role {
|
if copiedMsg.Role != originalMsg.Role {
|
||||||
t.Errorf("Copy() failed to preserve Role: got %q, want %q", copiedMsg.Role, originalMsg.Role)
|
t.Errorf("Copy() failed to preserve Role: got %q, want %q", copiedMsg.Role, originalMsg.Role)
|
||||||
}
|
}
|
||||||
@@ -635,7 +608,6 @@ func TestRoleMsgCopyPreservesKnownTo(t *testing.T) {
|
|||||||
t.Errorf("Copy() failed to preserve hasContentParts flag")
|
t.Errorf("Copy() failed to preserve hasContentParts flag")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKnownToFieldPreservationScenario(t *testing.T) {
|
func TestKnownToFieldPreservationScenario(t *testing.T) {
|
||||||
// Test the specific scenario from the log where KnownTo field was getting lost
|
// Test the specific scenario from the log where KnownTo field was getting lost
|
||||||
originalMsg := models.RoleMsg{
|
originalMsg := models.RoleMsg{
|
||||||
@@ -643,28 +615,22 @@ func TestKnownToFieldPreservationScenario(t *testing.T) {
|
|||||||
Content: `Alice: "Okay, Bob. The word is... **'Ephemeral'**. (ooc: @Bob@)"`,
|
Content: `Alice: "Okay, Bob. The word is... **'Ephemeral'**. (ooc: @Bob@)"`,
|
||||||
KnownTo: []string{"Bob"}, // This was detected in the log
|
KnownTo: []string{"Bob"}, // This was detected in the log
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("Original message - Role: %s, Content: %s, KnownTo: %v",
|
t.Logf("Original message - Role: %s, Content: %s, KnownTo: %v",
|
||||||
originalMsg.Role, originalMsg.Content, originalMsg.KnownTo)
|
originalMsg.Role, originalMsg.Content, originalMsg.KnownTo)
|
||||||
|
|
||||||
// Simulate what happens when the message gets copied during processing
|
// Simulate what happens when the message gets copied during processing
|
||||||
copiedMsg := originalMsg.Copy()
|
copiedMsg := originalMsg.Copy()
|
||||||
|
|
||||||
t.Logf("Copied message - Role: %s, Content: %s, KnownTo: %v",
|
t.Logf("Copied message - Role: %s, Content: %s, KnownTo: %v",
|
||||||
copiedMsg.Role, copiedMsg.Content, copiedMsg.KnownTo)
|
copiedMsg.Role, copiedMsg.Content, copiedMsg.KnownTo)
|
||||||
|
|
||||||
// Check if KnownTo field survived the copy
|
// Check if KnownTo field survived the copy
|
||||||
if len(copiedMsg.KnownTo) == 0 {
|
if len(copiedMsg.KnownTo) == 0 {
|
||||||
t.Error("ERROR: KnownTo field was lost during copy!")
|
t.Error("ERROR: KnownTo field was lost during copy!")
|
||||||
} else {
|
} else {
|
||||||
t.Log("SUCCESS: KnownTo field was preserved during copy!")
|
t.Log("SUCCESS: KnownTo field was preserved during copy!")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the content is the same
|
// Verify the content is the same
|
||||||
if copiedMsg.Content != originalMsg.Content {
|
if copiedMsg.Content != originalMsg.Content {
|
||||||
t.Errorf("Content was changed during copy: got %s, want %s", copiedMsg.Content, originalMsg.Content)
|
t.Errorf("Content was changed during copy: got %s, want %s", copiedMsg.Content, originalMsg.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the KnownTo slice is properly copied
|
// Verify the KnownTo slice is properly copied
|
||||||
if !reflect.DeepEqual(copiedMsg.KnownTo, originalMsg.KnownTo) {
|
if !reflect.DeepEqual(copiedMsg.KnownTo, originalMsg.KnownTo) {
|
||||||
t.Errorf("KnownTo was not properly copied: got %v, want %v", copiedMsg.KnownTo, originalMsg.KnownTo)
|
t.Errorf("KnownTo was not properly copied: got %v, want %v", copiedMsg.KnownTo, originalMsg.KnownTo)
|
||||||
|
|||||||
12
helpfuncs.go
12
helpfuncs.go
@@ -32,7 +32,6 @@ func startModelColorUpdater() {
|
|||||||
|
|
||||||
// Initial check
|
// Initial check
|
||||||
updateCachedModelColor()
|
updateCachedModelColor()
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
updateCachedModelColor()
|
updateCachedModelColor()
|
||||||
}
|
}
|
||||||
@@ -75,15 +74,16 @@ func stripThinkingFromMsg(msg *models.RoleMsg) *models.RoleMsg {
|
|||||||
if !cfg.StripThinkingFromAPI {
|
if !cfg.StripThinkingFromAPI {
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
// Skip user, tool, and system messages - they might contain thinking examples
|
// Skip user, tool, they might contain thinking and system messages - examples
|
||||||
if msg.Role == cfg.UserRole || msg.Role == cfg.ToolRole || msg.Role == "system" {
|
if msg.Role == cfg.UserRole || msg.Role == cfg.ToolRole || msg.Role == "system" {
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
// Strip thinking from assistant messages
|
// Strip thinking from assistant messages
|
||||||
if thinkRE.MatchString(msg.Content) {
|
msgText := msg.GetText()
|
||||||
msg.Content = thinkRE.ReplaceAllString(msg.Content, "")
|
if thinkRE.MatchString(msgText) {
|
||||||
// Clean up any double newlines that might result
|
cleanedText := thinkRE.ReplaceAllString(msgText, "")
|
||||||
msg.Content = strings.TrimSpace(msg.Content)
|
cleanedText = strings.TrimSpace(cleanedText)
|
||||||
|
msg.SetText(cleanedText)
|
||||||
}
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|||||||
3
llm.go
3
llm.go
@@ -216,13 +216,11 @@ func (op LCPChat) ParseChunk(data []byte) (*models.TextChunk, error) {
|
|||||||
logger.Warn("LCPChat ParseChunk: no choices in response", "data", string(data))
|
logger.Warn("LCPChat ParseChunk: no choices in response", "data", string(data))
|
||||||
return &models.TextChunk{Finished: true}, nil
|
return &models.TextChunk{Finished: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
lastChoice := llmchunk.Choices[len(llmchunk.Choices)-1]
|
lastChoice := llmchunk.Choices[len(llmchunk.Choices)-1]
|
||||||
resp := &models.TextChunk{
|
resp := &models.TextChunk{
|
||||||
Chunk: lastChoice.Delta.Content,
|
Chunk: lastChoice.Delta.Content,
|
||||||
Reasoning: lastChoice.Delta.ReasoningContent,
|
Reasoning: lastChoice.Delta.ReasoningContent,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for tool calls in all choices, not just the last one
|
// Check for tool calls in all choices, not just the last one
|
||||||
for _, choice := range llmchunk.Choices {
|
for _, choice := range llmchunk.Choices {
|
||||||
if len(choice.Delta.ToolCalls) > 0 {
|
if len(choice.Delta.ToolCalls) > 0 {
|
||||||
@@ -237,7 +235,6 @@ func (op LCPChat) ParseChunk(data []byte) (*models.TextChunk, error) {
|
|||||||
break // Process only the first tool call
|
break // Process only the first tool call
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastChoice.FinishReason == "stop" {
|
if lastChoice.FinishReason == "stop" {
|
||||||
if resp.Chunk != "" {
|
if resp.Chunk != "" {
|
||||||
logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
|
logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
|
||||||
|
|||||||
@@ -329,6 +329,66 @@ func (m *RoleMsg) Copy() RoleMsg {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetText returns the text content of the message, handling both
|
||||||
|
// simple Content and multimodal ContentParts formats.
|
||||||
|
func (m *RoleMsg) GetText() string {
|
||||||
|
if !m.hasContentParts {
|
||||||
|
return m.Content
|
||||||
|
}
|
||||||
|
var textParts []string
|
||||||
|
for _, part := range m.ContentParts {
|
||||||
|
switch p := part.(type) {
|
||||||
|
case TextContentPart:
|
||||||
|
if p.Type == "text" {
|
||||||
|
textParts = append(textParts, p.Text)
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
if partType, exists := p["type"]; exists {
|
||||||
|
if partType == "text" {
|
||||||
|
if textVal, textExists := p["text"]; textExists {
|
||||||
|
if textStr, isStr := textVal.(string); isStr {
|
||||||
|
textParts = append(textParts, textStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(textParts, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetText updates the text content of the message. If the message has
|
||||||
|
// ContentParts (multimodal), it updates the text parts while preserving
|
||||||
|
// images. If not, it sets the simple Content field.
|
||||||
|
func (m *RoleMsg) SetText(text string) {
|
||||||
|
if !m.hasContentParts {
|
||||||
|
m.Content = text
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var newParts []any
|
||||||
|
for _, part := range m.ContentParts {
|
||||||
|
switch p := part.(type) {
|
||||||
|
case TextContentPart:
|
||||||
|
if p.Type == "text" {
|
||||||
|
p.Text = text
|
||||||
|
newParts = append(newParts, p)
|
||||||
|
} else {
|
||||||
|
newParts = append(newParts, p)
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
if partType, exists := p["type"]; exists && partType == "text" {
|
||||||
|
p["text"] = text
|
||||||
|
newParts = append(newParts, p)
|
||||||
|
} else {
|
||||||
|
newParts = append(newParts, p)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
newParts = append(newParts, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.ContentParts = newParts
|
||||||
|
}
|
||||||
|
|
||||||
// AddTextPart adds a text content part to the message
|
// AddTextPart adds a text content part to the message
|
||||||
func (m *RoleMsg) AddTextPart(text string) {
|
func (m *RoleMsg) AddTextPart(text string) {
|
||||||
if !m.hasContentParts {
|
if !m.hasContentParts {
|
||||||
@@ -340,7 +400,6 @@ func (m *RoleMsg) AddTextPart(text string) {
|
|||||||
}
|
}
|
||||||
m.hasContentParts = true
|
m.hasContentParts = true
|
||||||
}
|
}
|
||||||
|
|
||||||
textPart := TextContentPart{Type: "text", Text: text}
|
textPart := TextContentPart{Type: "text", Text: text}
|
||||||
m.ContentParts = append(m.ContentParts, textPart)
|
m.ContentParts = append(m.ContentParts, textPart)
|
||||||
}
|
}
|
||||||
@@ -356,7 +415,6 @@ func (m *RoleMsg) AddImagePart(imageURL, imagePath string) {
|
|||||||
}
|
}
|
||||||
m.hasContentParts = true
|
m.hasContentParts = true
|
||||||
}
|
}
|
||||||
|
|
||||||
imagePart := ImageContentPart{
|
imagePart := ImageContentPart{
|
||||||
Type: "image_url",
|
Type: "image_url",
|
||||||
Path: imagePath, // Store the original file path
|
Path: imagePath, // Store the original file path
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRoleMsgToTextWithImages(t *testing.T) {
|
func TestRoleMsgToTextWithImages(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -92,7 +90,6 @@ func TestRoleMsgToTextWithImages(t *testing.T) {
|
|||||||
expected: "[orange::i][image: /old/path/photo.jpg][-:-:-]",
|
expected: "[orange::i][image: /old/path/photo.jpg][-:-:-]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := tt.msg.ToText(tt.index)
|
result := tt.msg.ToText(tt.index)
|
||||||
@@ -110,12 +107,10 @@ func TestRoleMsgToTextWithImages(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractDisplayPath(t *testing.T) {
|
func TestExtractDisplayPath(t *testing.T) {
|
||||||
// Save original base dir
|
// Save original base dir
|
||||||
originalBaseDir := imageBaseDir
|
originalBaseDir := imageBaseDir
|
||||||
defer func() { imageBaseDir = originalBaseDir }()
|
defer func() { imageBaseDir = originalBaseDir }()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
baseDir string
|
baseDir string
|
||||||
@@ -153,7 +148,6 @@ func TestExtractDisplayPath(t *testing.T) {
|
|||||||
expected: "..._that_exceeds_sixty_characters_limit_yes_it_is_very_long.jpg",
|
expected: "..._that_exceeds_sixty_characters_limit_yes_it_is_very_long.jpg",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
imageBaseDir = tt.baseDir
|
imageBaseDir = tt.baseDir
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ func TestORModelsListModels(t *testing.T) {
|
|||||||
t.Errorf("expected 4 total models, got %d", len(allModels))
|
t.Errorf("expected 4 total models, got %d", len(allModels))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("integration with or_models.json", func(t *testing.T) {
|
t.Run("integration with or_models.json", func(t *testing.T) {
|
||||||
// Attempt to load the real data file from the project root
|
// Attempt to load the real data file from the project root
|
||||||
path := filepath.Join("..", "or_models.json")
|
path := filepath.Join("..", "or_models.json")
|
||||||
|
|||||||
@@ -410,38 +410,30 @@ func updateWidgetColors(theme *tview.Theme) {
|
|||||||
fgColor := theme.PrimaryTextColor
|
fgColor := theme.PrimaryTextColor
|
||||||
borderColor := theme.BorderColor
|
borderColor := theme.BorderColor
|
||||||
titleColor := theme.TitleColor
|
titleColor := theme.TitleColor
|
||||||
|
|
||||||
textView.SetBackgroundColor(bgColor)
|
textView.SetBackgroundColor(bgColor)
|
||||||
textView.SetTextColor(fgColor)
|
textView.SetTextColor(fgColor)
|
||||||
textView.SetBorderColor(borderColor)
|
textView.SetBorderColor(borderColor)
|
||||||
textView.SetTitleColor(titleColor)
|
textView.SetTitleColor(titleColor)
|
||||||
|
|
||||||
textArea.SetBackgroundColor(bgColor)
|
textArea.SetBackgroundColor(bgColor)
|
||||||
textArea.SetBorderColor(borderColor)
|
textArea.SetBorderColor(borderColor)
|
||||||
textArea.SetTitleColor(titleColor)
|
textArea.SetTitleColor(titleColor)
|
||||||
textArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
textArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
textArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
textArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
// Force textarea refresh by restoring text (SetTextStyle doesn't trigger redraw)
|
|
||||||
textArea.SetText(textArea.GetText(), true)
|
textArea.SetText(textArea.GetText(), true)
|
||||||
|
|
||||||
editArea.SetBackgroundColor(bgColor)
|
editArea.SetBackgroundColor(bgColor)
|
||||||
editArea.SetBorderColor(borderColor)
|
editArea.SetBorderColor(borderColor)
|
||||||
editArea.SetTitleColor(titleColor)
|
editArea.SetTitleColor(titleColor)
|
||||||
editArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
editArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
editArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
editArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
// Force textarea refresh by restoring text (SetTextStyle doesn't trigger redraw)
|
|
||||||
editArea.SetText(editArea.GetText(), true)
|
editArea.SetText(editArea.GetText(), true)
|
||||||
|
|
||||||
statusLineWidget.SetBackgroundColor(bgColor)
|
statusLineWidget.SetBackgroundColor(bgColor)
|
||||||
statusLineWidget.SetTextColor(fgColor)
|
statusLineWidget.SetTextColor(fgColor)
|
||||||
statusLineWidget.SetBorderColor(borderColor)
|
statusLineWidget.SetBorderColor(borderColor)
|
||||||
statusLineWidget.SetTitleColor(titleColor)
|
statusLineWidget.SetTitleColor(titleColor)
|
||||||
|
|
||||||
helpView.SetBackgroundColor(bgColor)
|
helpView.SetBackgroundColor(bgColor)
|
||||||
helpView.SetTextColor(fgColor)
|
helpView.SetTextColor(fgColor)
|
||||||
helpView.SetBorderColor(borderColor)
|
helpView.SetBorderColor(borderColor)
|
||||||
helpView.SetTitleColor(titleColor)
|
helpView.SetTitleColor(titleColor)
|
||||||
|
|
||||||
searchField.SetBackgroundColor(bgColor)
|
searchField.SetBackgroundColor(bgColor)
|
||||||
searchField.SetBorderColor(borderColor)
|
searchField.SetBorderColor(borderColor)
|
||||||
searchField.SetTitleColor(titleColor)
|
searchField.SetTitleColor(titleColor)
|
||||||
@@ -468,7 +460,6 @@ func showColorschemeSelectionPopup() {
|
|||||||
schemeListWidget := tview.NewList().ShowSecondaryText(false).
|
schemeListWidget := tview.NewList().ShowSecondaryText(false).
|
||||||
SetSelectedBackgroundColor(tcell.ColorGray)
|
SetSelectedBackgroundColor(tcell.ColorGray)
|
||||||
schemeListWidget.SetTitle("Select Colorscheme").SetBorder(true)
|
schemeListWidget.SetTitle("Select Colorscheme").SetBorder(true)
|
||||||
|
|
||||||
currentScheme := "default"
|
currentScheme := "default"
|
||||||
for name := range colorschemes {
|
for name := range colorschemes {
|
||||||
if tview.Styles == colorschemes[name] {
|
if tview.Styles == colorschemes[name] {
|
||||||
|
|||||||
@@ -131,7 +131,6 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
|||||||
}
|
}
|
||||||
embeddings[data.Index] = data.Embedding
|
embeddings[data.Index] = data.Embedding
|
||||||
}
|
}
|
||||||
|
|
||||||
return embeddings, nil
|
return embeddings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -95,9 +95,7 @@ func extractTextFromEpub(fpath string) (string, error) {
|
|||||||
return "", fmt.Errorf("failed to open epub: %w", err)
|
return "", fmt.Errorf("failed to open epub: %w", err)
|
||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
for _, f := range r.File {
|
for _, f := range r.File {
|
||||||
ext := strings.ToLower(path.Ext(f.Name))
|
ext := strings.ToLower(path.Ext(f.Name))
|
||||||
if ext != ".xhtml" && ext != ".html" && ext != ".htm" && ext != ".xml" {
|
if ext != ".xhtml" && ext != ".html" && ext != ".htm" && ext != ".xml" {
|
||||||
@@ -129,7 +127,6 @@ func extractTextFromEpub(fpath string) (string, error) {
|
|||||||
sb.WriteString(stripHTML(string(buf)))
|
sb.WriteString(stripHTML(string(buf)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sb.Len() == 0 {
|
if sb.Len() == 0 {
|
||||||
return "", errors.New("no content extracted from epub")
|
return "", errors.New("no content extracted from epub")
|
||||||
}
|
}
|
||||||
|
|||||||
55
rag/rag.go
55
rag/rag.go
@@ -36,7 +36,6 @@ type RAG struct {
|
|||||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
||||||
// Initialize with API embedder by default, could be configurable later
|
// Initialize with API embedder by default, could be configurable later
|
||||||
embedder := NewAPIEmbedder(l, cfg)
|
embedder := NewAPIEmbedder(l, cfg)
|
||||||
|
|
||||||
rag := &RAG{
|
rag := &RAG{
|
||||||
logger: l,
|
logger: l,
|
||||||
store: s,
|
store: s,
|
||||||
@@ -205,29 +204,22 @@ var (
|
|||||||
func (r *RAG) RefineQuery(query string) string {
|
func (r *RAG) RefineQuery(query string) string {
|
||||||
original := query
|
original := query
|
||||||
query = strings.TrimSpace(query)
|
query = strings.TrimSpace(query)
|
||||||
|
|
||||||
if len(query) == 0 {
|
if len(query) == 0 {
|
||||||
return original
|
return original
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(query) <= 3 {
|
if len(query) <= 3 {
|
||||||
return original
|
return original
|
||||||
}
|
}
|
||||||
|
|
||||||
query = strings.ToLower(query)
|
query = strings.ToLower(query)
|
||||||
|
|
||||||
for _, stopWord := range stopWords {
|
for _, stopWord := range stopWords {
|
||||||
wordPattern := `\b` + stopWord + `\b`
|
wordPattern := `\b` + stopWord + `\b`
|
||||||
re := regexp.MustCompile(wordPattern)
|
re := regexp.MustCompile(wordPattern)
|
||||||
query = re.ReplaceAllString(query, "")
|
query = re.ReplaceAllString(query, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
query = strings.TrimSpace(query)
|
query = strings.TrimSpace(query)
|
||||||
|
|
||||||
if len(query) < 5 {
|
if len(query) < 5 {
|
||||||
return original
|
return original
|
||||||
}
|
}
|
||||||
|
|
||||||
if queryRefinementPattern.MatchString(original) {
|
if queryRefinementPattern.MatchString(original) {
|
||||||
cleaned := queryRefinementPattern.ReplaceAllString(original, "")
|
cleaned := queryRefinementPattern.ReplaceAllString(original, "")
|
||||||
cleaned = strings.TrimSpace(cleaned)
|
cleaned = strings.TrimSpace(cleaned)
|
||||||
@@ -235,23 +227,18 @@ func (r *RAG) RefineQuery(query string) string {
|
|||||||
return cleaned
|
return cleaned
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
query = r.extractImportantPhrases(query)
|
query = r.extractImportantPhrases(query)
|
||||||
|
|
||||||
if len(query) < 5 {
|
if len(query) < 5 {
|
||||||
return original
|
return original
|
||||||
}
|
}
|
||||||
|
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) extractImportantPhrases(query string) string {
|
func (r *RAG) extractImportantPhrases(query string) string {
|
||||||
words := strings.Fields(query)
|
words := strings.Fields(query)
|
||||||
|
|
||||||
var important []string
|
var important []string
|
||||||
for _, word := range words {
|
for _, word := range words {
|
||||||
word = strings.Trim(word, ".,!?;:'\"()[]{}")
|
word = strings.Trim(word, ".,!?;:'\"()[]{}")
|
||||||
|
|
||||||
isImportant := false
|
isImportant := false
|
||||||
for _, kw := range importantKeywords {
|
for _, kw := range importantKeywords {
|
||||||
if strings.Contains(strings.ToLower(word), kw) {
|
if strings.Contains(strings.ToLower(word), kw) {
|
||||||
@@ -259,45 +246,37 @@ func (r *RAG) extractImportantPhrases(query string) string {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isImportant || len(word) > 3 {
|
if isImportant || len(word) > 3 {
|
||||||
important = append(important, word)
|
important = append(important, word)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(important) == 0 {
|
if len(important) == 0 {
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(important, " ")
|
return strings.Join(important, " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) GenerateQueryVariations(query string) []string {
|
func (r *RAG) GenerateQueryVariations(query string) []string {
|
||||||
variations := []string{query}
|
variations := []string{query}
|
||||||
|
|
||||||
if len(query) < 5 {
|
if len(query) < 5 {
|
||||||
return variations
|
return variations
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Fields(query)
|
parts := strings.Fields(query)
|
||||||
if len(parts) == 0 {
|
if len(parts) == 0 {
|
||||||
return variations
|
return variations
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
trimmed := strings.Join(parts[:len(parts)-1], " ")
|
trimmed := strings.Join(parts[:len(parts)-1], " ")
|
||||||
if len(trimmed) >= 5 {
|
if len(trimmed) >= 5 {
|
||||||
variations = append(variations, trimmed)
|
variations = append(variations, trimmed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
trimmed := strings.Join(parts[1:], " ")
|
trimmed := strings.Join(parts[1:], " ")
|
||||||
if len(trimmed) >= 5 {
|
if len(trimmed) >= 5 {
|
||||||
variations = append(variations, trimmed)
|
variations = append(variations, trimmed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasSuffix(query, " explanation") {
|
if !strings.HasSuffix(query, " explanation") {
|
||||||
variations = append(variations, query+" explanation")
|
variations = append(variations, query+" explanation")
|
||||||
}
|
}
|
||||||
@@ -310,7 +289,6 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
|
|||||||
if !strings.HasSuffix(query, " summary") {
|
if !strings.HasSuffix(query, " summary") {
|
||||||
variations = append(variations, query+" summary")
|
variations = append(variations, query+" summary")
|
||||||
}
|
}
|
||||||
|
|
||||||
return variations
|
return variations
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,21 +297,16 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
|||||||
row models.VectorRow
|
row models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
|
|
||||||
scored := make([]scoredResult, 0, len(results))
|
scored := make([]scoredResult, 0, len(results))
|
||||||
|
|
||||||
for i := range results {
|
for i := range results {
|
||||||
row := results[i]
|
row := results[i]
|
||||||
|
|
||||||
score := float32(0)
|
score := float32(0)
|
||||||
|
|
||||||
rawTextLower := strings.ToLower(row.RawText)
|
rawTextLower := strings.ToLower(row.RawText)
|
||||||
queryLower := strings.ToLower(query)
|
queryLower := strings.ToLower(query)
|
||||||
|
|
||||||
if strings.Contains(rawTextLower, queryLower) {
|
if strings.Contains(rawTextLower, queryLower) {
|
||||||
score += 10
|
score += 10
|
||||||
}
|
}
|
||||||
|
|
||||||
queryWords := strings.Fields(queryLower)
|
queryWords := strings.Fields(queryLower)
|
||||||
matchCount := 0
|
matchCount := 0
|
||||||
for _, word := range queryWords {
|
for _, word := range queryWords {
|
||||||
@@ -344,34 +317,26 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
|||||||
if len(queryWords) > 0 {
|
if len(queryWords) > 0 {
|
||||||
score += float32(matchCount) / float32(len(queryWords)) * 5
|
score += float32(matchCount) / float32(len(queryWords)) * 5
|
||||||
}
|
}
|
||||||
|
|
||||||
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
|
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
|
||||||
score += 3
|
score += 3
|
||||||
}
|
}
|
||||||
|
|
||||||
distance := row.Distance - score/100
|
distance := row.Distance - score/100
|
||||||
|
|
||||||
scored = append(scored, scoredResult{row: row, distance: distance})
|
scored = append(scored, scoredResult{row: row, distance: distance})
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(scored, func(i, j int) bool {
|
sort.Slice(scored, func(i, j int) bool {
|
||||||
return scored[i].distance < scored[j].distance
|
return scored[i].distance < scored[j].distance
|
||||||
})
|
})
|
||||||
|
|
||||||
unique := make([]models.VectorRow, 0)
|
unique := make([]models.VectorRow, 0)
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
for i := range scored {
|
for i := range scored {
|
||||||
if !seen[scored[i].row.Slug] {
|
if !seen[scored[i].row.Slug] {
|
||||||
seen[scored[i].row.Slug] = true
|
seen[scored[i].row.Slug] = true
|
||||||
unique = append(unique, scored[i].row)
|
unique = append(unique, scored[i].row)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(unique) > 10 {
|
if len(unique) > 10 {
|
||||||
unique = unique[:10]
|
unique = unique[:10]
|
||||||
}
|
}
|
||||||
|
|
||||||
return unique
|
return unique
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,58 +344,47 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
|
|||||||
if len(results) == 0 {
|
if len(results) == 0 {
|
||||||
return "No relevant information found in the vector database.", nil
|
return "No relevant information found in the vector database.", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var contextBuilder strings.Builder
|
var contextBuilder strings.Builder
|
||||||
contextBuilder.WriteString("User Query: ")
|
contextBuilder.WriteString("User Query: ")
|
||||||
contextBuilder.WriteString(query)
|
contextBuilder.WriteString(query)
|
||||||
contextBuilder.WriteString("\n\nRetrieved Context:\n")
|
contextBuilder.WriteString("\n\nRetrieved Context:\n")
|
||||||
|
|
||||||
for i, row := range results {
|
for i, row := range results {
|
||||||
contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName))
|
fmt.Fprintf(&contextBuilder, "[Source %d: %s]\n", i+1, row.FileName)
|
||||||
contextBuilder.WriteString(row.RawText)
|
contextBuilder.WriteString(row.RawText)
|
||||||
contextBuilder.WriteString("\n\n")
|
contextBuilder.WriteString("\n\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
contextBuilder.WriteString("Instructions: ")
|
contextBuilder.WriteString("Instructions: ")
|
||||||
contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
|
contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
|
||||||
contextBuilder.WriteString("Extract only the most relevant information. ")
|
contextBuilder.WriteString("Extract only the most relevant information. ")
|
||||||
contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
|
contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
|
||||||
contextBuilder.WriteString("Cite sources by filename when relevant. ")
|
contextBuilder.WriteString("Cite sources by filename when relevant. ")
|
||||||
contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
|
contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
|
||||||
|
|
||||||
synthesisPrompt := contextBuilder.String()
|
synthesisPrompt := contextBuilder.String()
|
||||||
|
|
||||||
emb, err := r.LineToVector(synthesisPrompt)
|
emb, err := r.LineToVector(synthesisPrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to embed synthesis prompt", "error", err)
|
r.logger.Error("failed to embed synthesis prompt", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
embResp := &models.EmbeddingResp{
|
embResp := &models.EmbeddingResp{
|
||||||
Embedding: emb,
|
Embedding: emb,
|
||||||
Index: 0,
|
Index: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
topResults, err := r.SearchEmb(embResp)
|
topResults, err := r.SearchEmb(embResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to search for synthesis context", "error", err)
|
r.logger.Error("failed to search for synthesis context", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
|
if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
|
||||||
return topResults[0].RawText, nil
|
return topResults[0].RawText, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var finalAnswer strings.Builder
|
var finalAnswer strings.Builder
|
||||||
finalAnswer.WriteString("Based on the retrieved context:\n\n")
|
finalAnswer.WriteString("Based on the retrieved context:\n\n")
|
||||||
|
|
||||||
for i, row := range results {
|
for i, row := range results {
|
||||||
if i >= 5 {
|
if i >= 5 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200)))
|
fmt.Fprintf(&finalAnswer, "- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))
|
||||||
}
|
}
|
||||||
|
|
||||||
return finalAnswer.String(), nil
|
return finalAnswer.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -444,10 +398,8 @@ func truncateString(s string, maxLen int) string {
|
|||||||
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||||
refined := r.RefineQuery(query)
|
refined := r.RefineQuery(query)
|
||||||
variations := r.GenerateQueryVariations(refined)
|
variations := r.GenerateQueryVariations(refined)
|
||||||
|
|
||||||
allResults := make([]models.VectorRow, 0)
|
allResults := make([]models.VectorRow, 0)
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
for _, q := range variations {
|
for _, q := range variations {
|
||||||
emb, err := r.LineToVector(q)
|
emb, err := r.LineToVector(q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -473,13 +425,10 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reranked := r.RerankResults(allResults, query)
|
reranked := r.RerankResults(allResults, query)
|
||||||
|
|
||||||
if len(reranked) > limit {
|
if len(reranked) > limit {
|
||||||
reranked = reranked[:limit]
|
reranked = reranked[:limit]
|
||||||
}
|
}
|
||||||
|
|
||||||
return reranked, nil
|
return reranked, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ func NewVectorStorage(logger *slog.Logger, store storage.FullRepo) *VectorStorag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// SerializeVector converts []float32 to binary blob
|
// SerializeVector converts []float32 to binary blob
|
||||||
func SerializeVector(vec []float32) []byte {
|
func SerializeVector(vec []float32) []byte {
|
||||||
buf := make([]byte, len(vec)*4) // 4 bytes per float32
|
buf := make([]byte, len(vec)*4) // 4 bytes per float32
|
||||||
@@ -66,17 +65,14 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
|
|||||||
|
|
||||||
// Serialize the embeddings to binary
|
// Serialize the embeddings to binary
|
||||||
serializedEmbeddings := SerializeVector(row.Embeddings)
|
serializedEmbeddings := SerializeVector(row.Embeddings)
|
||||||
|
|
||||||
query := fmt.Sprintf(
|
query := fmt.Sprintf(
|
||||||
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
|
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
|
||||||
tableName,
|
tableName,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
|
if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
|
||||||
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
|
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,20 +82,18 @@ func (vs *VectorStorage) getTableName(emb []float32) (string, error) {
|
|||||||
|
|
||||||
// Check if we support this embedding size
|
// Check if we support this embedding size
|
||||||
supportedSizes := map[int]bool{
|
supportedSizes := map[int]bool{
|
||||||
384: true,
|
384: true,
|
||||||
768: true,
|
768: true,
|
||||||
1024: true,
|
1024: true,
|
||||||
1536: true,
|
1536: true,
|
||||||
2048: true,
|
2048: true,
|
||||||
3072: true,
|
3072: true,
|
||||||
4096: true,
|
4096: true,
|
||||||
5120: true,
|
5120: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if supportedSizes[size] {
|
if supportedSizes[size] {
|
||||||
return fmt.Sprintf("embeddings_%d", size), nil
|
return fmt.Sprintf("embeddings_%d", size), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("no table for embedding size of %d", size)
|
return "", fmt.Errorf("no table for embedding size of %d", size)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,9 +120,7 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
vector models.VectorRow
|
vector models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
|
|
||||||
var topResults []SearchResult
|
var topResults []SearchResult
|
||||||
|
|
||||||
// Process vectors one by one to avoid loading everything into memory
|
// Process vectors one by one to avoid loading everything into memory
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
@@ -176,14 +168,12 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
result.vector.Distance = result.distance
|
result.vector.Distance = result.distance
|
||||||
results = append(results, result.vector)
|
results = append(results, result.vector)
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListFiles returns a list of all loaded files
|
// ListFiles returns a list of all loaded files
|
||||||
func (vs *VectorStorage) ListFiles() ([]string, error) {
|
func (vs *VectorStorage) ListFiles() ([]string, error) {
|
||||||
fileLists := make([][]string, 0)
|
fileLists := make([][]string, 0)
|
||||||
|
|
||||||
// Query all supported tables and combine results
|
// Query all supported tables and combine results
|
||||||
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
||||||
for _, size := range embeddingSizes {
|
for _, size := range embeddingSizes {
|
||||||
@@ -219,14 +209,12 @@ func (vs *VectorStorage) ListFiles() ([]string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return allFiles, nil
|
return allFiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveEmbByFileName removes all embeddings associated with a specific filename
|
// RemoveEmbByFileName removes all embeddings associated with a specific filename
|
||||||
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
||||||
var errors []string
|
var errors []string
|
||||||
|
|
||||||
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
||||||
for _, size := range embeddingSizes {
|
for _, size := range embeddingSizes {
|
||||||
table := fmt.Sprintf("embeddings_%d", size)
|
table := fmt.Sprintf("embeddings_%d", size)
|
||||||
@@ -235,11 +223,9 @@ func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
|||||||
errors = append(errors, err.Error())
|
errors = append(errors, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(errors) > 0 {
|
if len(errors) > 0 {
|
||||||
return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; "))
|
return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; "))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,18 +234,15 @@ func cosineSimilarity(a, b []float32) float32 {
|
|||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
var dotProduct, normA, normB float32
|
var dotProduct, normA, normB float32
|
||||||
for i := 0; i < len(a); i++ {
|
for i := 0; i < len(a); i++ {
|
||||||
dotProduct += a[i] * b[i]
|
dotProduct += a[i] * b[i]
|
||||||
normA += a[i] * a[i]
|
normA += a[i] * a[i]
|
||||||
normB += b[i] * b[i]
|
normB += b[i] * b[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
if normA == 0 || normB == 0 {
|
if normA == 0 || normB == 0 {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,4 +258,3 @@ func sqrt(f float32) float32 {
|
|||||||
}
|
}
|
||||||
return guess
|
return guess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -103,7 +103,6 @@ func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
p := ProviderSQL{db: db, logger: logger}
|
p := ProviderSQL{db: db, logger: logger}
|
||||||
|
|
||||||
p.Migrate()
|
p.Migrate()
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,12 +73,9 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
serializedEmbeddings := SerializeVector(row.Embeddings)
|
serializedEmbeddings := SerializeVector(row.Embeddings)
|
||||||
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
|
query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
|
||||||
_, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName)
|
_, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,27 +84,22 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
|
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
|
||||||
rows, err := p.db.Query(querySQL)
|
rows, err := p.db.Query(querySQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
type SearchResult struct {
|
type SearchResult struct {
|
||||||
vector models.VectorRow
|
vector models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
|
|
||||||
var topResults []SearchResult
|
var topResults []SearchResult
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
embeddingsBlob []byte
|
embeddingsBlob []byte
|
||||||
slug, rawText, fileName string
|
slug, rawText, fileName string
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil {
|
if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -152,7 +144,6 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
|||||||
result.vector.Distance = result.distance
|
result.vector.Distance = result.distance
|
||||||
results[i] = result.vector
|
results[i] = result.vector
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,18 +152,15 @@ func cosineSimilarity(a, b []float32) float32 {
|
|||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
var dotProduct, normA, normB float32
|
var dotProduct, normA, normB float32
|
||||||
for i := 0; i < len(a); i++ {
|
for i := 0; i < len(a); i++ {
|
||||||
dotProduct += a[i] * b[i]
|
dotProduct += a[i] * b[i]
|
||||||
normA += a[i] * a[i]
|
normA += a[i] * a[i]
|
||||||
normB += b[i] * b[i]
|
normB += b[i] * b[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
if normA == 0 || normB == 0 {
|
if normA == 0 || normB == 0 {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,13 +217,11 @@ func (p ProviderSQL) ListFiles() ([]string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return allFiles, nil
|
return allFiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
||||||
var errors []string
|
var errors []string
|
||||||
|
|
||||||
tableNames := []string{
|
tableNames := []string{
|
||||||
"embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536",
|
"embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536",
|
||||||
"embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120",
|
"embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120",
|
||||||
@@ -246,10 +232,8 @@ func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
|||||||
errors = append(errors, err.Error())
|
errors = append(errors, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(errors) > 0 {
|
if len(errors) > 0 {
|
||||||
return fmt.Errorf("errors occurred: %v", errors)
|
return fmt.Errorf("errors occurred: %v", errors)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
230
tables.go
230
tables.go
@@ -287,7 +287,6 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rows := len(ragFiles)
|
rows := len(ragFiles)
|
||||||
cols := 4 // File Name | Preview | Action | Delete
|
cols := 4 // File Name | Preview | Action | Delete
|
||||||
fileTable := tview.NewTable().
|
fileTable := tview.NewTable().
|
||||||
@@ -327,8 +326,8 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex {
|
|||||||
f := ragFiles[r]
|
f := ragFiles[r]
|
||||||
for c := 0; c < cols; c++ {
|
for c := 0; c < cols; c++ {
|
||||||
color := tcell.ColorWhite
|
color := tcell.ColorWhite
|
||||||
switch {
|
switch c {
|
||||||
case c == 0:
|
case 0:
|
||||||
displayName := f.name
|
displayName := f.name
|
||||||
if !f.inRAGDir {
|
if !f.inRAGDir {
|
||||||
displayName = f.name + " (orphaned)"
|
displayName = f.name + " (orphaned)"
|
||||||
@@ -338,7 +337,7 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex {
|
|||||||
SetTextColor(color).
|
SetTextColor(color).
|
||||||
SetAlign(tview.AlignCenter).
|
SetAlign(tview.AlignCenter).
|
||||||
SetSelectable(false))
|
SetSelectable(false))
|
||||||
case c == 1:
|
case 1:
|
||||||
if !f.inRAGDir {
|
if !f.inRAGDir {
|
||||||
// Orphaned file - no preview available
|
// Orphaned file - no preview available
|
||||||
fileTable.SetCell(r+1, c,
|
fileTable.SetCell(r+1, c,
|
||||||
@@ -362,7 +361,7 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex {
|
|||||||
SetAlign(tview.AlignCenter).
|
SetAlign(tview.AlignCenter).
|
||||||
SetSelectable(false))
|
SetSelectable(false))
|
||||||
}
|
}
|
||||||
case c == 2:
|
case 2:
|
||||||
actionText := "load"
|
actionText := "load"
|
||||||
if f.isLoaded {
|
if f.isLoaded {
|
||||||
actionText = "unload"
|
actionText = "unload"
|
||||||
@@ -375,7 +374,7 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex {
|
|||||||
tview.NewTableCell(actionText).
|
tview.NewTableCell(actionText).
|
||||||
SetTextColor(color).
|
SetTextColor(color).
|
||||||
SetAlign(tview.AlignCenter))
|
SetAlign(tview.AlignCenter))
|
||||||
case c == 3:
|
case 3:
|
||||||
if !f.inRAGDir {
|
if !f.inRAGDir {
|
||||||
// Orphaned file - cannot delete from ragdir (not there)
|
// Orphaned file - cannot delete from ragdir (not there)
|
||||||
fileTable.SetCell(r+1, c,
|
fileTable.SetCell(r+1, c,
|
||||||
@@ -513,138 +512,6 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex {
|
|||||||
return ragflex
|
return ragflex
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeLoadedRAGTable(fileList []string) *tview.Flex {
|
|
||||||
actions := []string{"delete"}
|
|
||||||
rows, cols := len(fileList), len(actions)+2
|
|
||||||
// Add 1 extra row for the "exit" option at the top
|
|
||||||
fileTable := tview.NewTable().
|
|
||||||
SetBorders(true)
|
|
||||||
longStatusView := tview.NewTextView()
|
|
||||||
longStatusView.SetText("Loaded RAG files list")
|
|
||||||
longStatusView.SetBorder(true).SetTitle("status")
|
|
||||||
longStatusView.SetChangedFunc(func() {
|
|
||||||
app.Draw()
|
|
||||||
})
|
|
||||||
ragflex := tview.NewFlex().SetDirection(tview.FlexRow).
|
|
||||||
AddItem(longStatusView, 0, 10, false).
|
|
||||||
AddItem(fileTable, 0, 60, true)
|
|
||||||
// Add the exit option as the first row (row 0)
|
|
||||||
fileTable.SetCell(0, 0,
|
|
||||||
tview.NewTableCell("File Name").
|
|
||||||
SetTextColor(tcell.ColorWhite).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
fileTable.SetCell(0, 1,
|
|
||||||
tview.NewTableCell("Preview").
|
|
||||||
SetTextColor(tcell.ColorWhite).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
fileTable.SetCell(0, 2,
|
|
||||||
tview.NewTableCell("Load").
|
|
||||||
SetTextColor(tcell.ColorWhite).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
fileTable.SetCell(0, 3,
|
|
||||||
tview.NewTableCell("Delete").
|
|
||||||
SetTextColor(tcell.ColorWhite).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
// Add the file rows starting from row 1
|
|
||||||
for r := 0; r < rows; r++ {
|
|
||||||
for c := 0; c < cols; c++ {
|
|
||||||
color := tcell.ColorWhite
|
|
||||||
switch {
|
|
||||||
case c == 0:
|
|
||||||
fileTable.SetCell(r+1, c,
|
|
||||||
tview.NewTableCell(fileList[r]).
|
|
||||||
SetTextColor(color).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
case c == 1:
|
|
||||||
if fi, err := os.Stat(fileList[r]); err == nil {
|
|
||||||
size := fi.Size()
|
|
||||||
modTime := fi.ModTime()
|
|
||||||
preview := fmt.Sprintf("%s | %s", formatSize(size), modTime.Format("2006-01-02 15:04"))
|
|
||||||
fileTable.SetCell(r+1, c,
|
|
||||||
tview.NewTableCell(preview).
|
|
||||||
SetTextColor(color).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
} else {
|
|
||||||
fileTable.SetCell(r+1, c,
|
|
||||||
tview.NewTableCell("error").
|
|
||||||
SetTextColor(color).
|
|
||||||
SetAlign(tview.AlignCenter).
|
|
||||||
SetSelectable(false))
|
|
||||||
}
|
|
||||||
case c == 2:
|
|
||||||
fileTable.SetCell(r+1, c,
|
|
||||||
tview.NewTableCell("load").
|
|
||||||
SetTextColor(color).
|
|
||||||
SetAlign(tview.AlignCenter))
|
|
||||||
default:
|
|
||||||
fileTable.SetCell(r+1, c,
|
|
||||||
tview.NewTableCell("delete").
|
|
||||||
SetTextColor(color).
|
|
||||||
SetAlign(tview.AlignCenter))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fileTable.Select(0, 0).
|
|
||||||
SetFixed(1, 1).
|
|
||||||
SetSelectable(true, true).
|
|
||||||
SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)).
|
|
||||||
SetDoneFunc(func(key tcell.Key) {
|
|
||||||
if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') || key == tcell.KeyCtrlX {
|
|
||||||
pages.RemovePage(RAGLoadedPage)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}).SetSelectedFunc(func(row int, column int) {
|
|
||||||
// If user selects a non-actionable column (0 or 1), move to first action column (2)
|
|
||||||
if column <= 1 {
|
|
||||||
if fileTable.GetColumnCount() > 2 {
|
|
||||||
fileTable.Select(row, 2) // Select first action column
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tc := fileTable.GetCell(row, column)
|
|
||||||
tc.SetTextColor(tcell.ColorRed)
|
|
||||||
fileTable.SetSelectable(false, false)
|
|
||||||
// Check if the selected row is the exit row (row 0) - do this first to avoid index issues
|
|
||||||
if row == 0 {
|
|
||||||
pages.RemovePage(RAGLoadedPage)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// For file rows, get the filename (row index - 1 because of the exit row at index 0)
|
|
||||||
fpath := fileList[row-1] // -1 to account for the exit row at index 0
|
|
||||||
switch tc.Text {
|
|
||||||
case "delete":
|
|
||||||
if err := ragger.RemoveFile(fpath); err != nil {
|
|
||||||
logger.Error("failed to delete file from RAG", "filename", fpath, "error", err)
|
|
||||||
longStatusView.SetText(fmt.Sprintf("Error deleting file: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := notifyUser("RAG file deleted", fpath+" was deleted from RAG system"); err != nil {
|
|
||||||
logger.Error("failed to send notification", "error", err)
|
|
||||||
}
|
|
||||||
longStatusView.SetText(fpath + " was deleted from RAG system")
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
pages.RemovePage(RAGLoadedPage)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
// Add input capture to the flex container to handle 'x' key for closing
|
|
||||||
ragflex.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
|
||||||
if event.Key() == tcell.KeyRune && event.Rune() == 'x' {
|
|
||||||
pages.RemovePage(RAGLoadedPage)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return event
|
|
||||||
})
|
|
||||||
return ragflex
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeAgentTable(agentList []string) *tview.Table {
|
func makeAgentTable(agentList []string) *tview.Table {
|
||||||
actions := []string{"filepath", "load"}
|
actions := []string{"filepath", "load"}
|
||||||
rows, cols := len(agentList), len(actions)+1
|
rows, cols := len(agentList), len(actions)+1
|
||||||
@@ -653,14 +520,14 @@ func makeAgentTable(agentList []string) *tview.Table {
|
|||||||
for r := 0; r < rows; r++ {
|
for r := 0; r < rows; r++ {
|
||||||
for c := 0; c < cols; c++ {
|
for c := 0; c < cols; c++ {
|
||||||
color := tcell.ColorWhite
|
color := tcell.ColorWhite
|
||||||
switch {
|
switch c {
|
||||||
case c < 1:
|
case 0:
|
||||||
chatActTable.SetCell(r, c,
|
chatActTable.SetCell(r, c,
|
||||||
tview.NewTableCell(agentList[r]).
|
tview.NewTableCell(agentList[r]).
|
||||||
SetTextColor(color).
|
SetTextColor(color).
|
||||||
SetAlign(tview.AlignCenter).
|
SetAlign(tview.AlignCenter).
|
||||||
SetSelectable(false))
|
SetSelectable(false))
|
||||||
case c == 1:
|
case 1:
|
||||||
if actions[c-1] == "filepath" {
|
if actions[c-1] == "filepath" {
|
||||||
cc, ok := sysMap[agentList[r]]
|
cc, ok := sysMap[agentList[r]]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -952,6 +819,7 @@ func makeFilePicker() *tview.Flex {
|
|||||||
// --- NEW: search state ---
|
// --- NEW: search state ---
|
||||||
searching := false
|
searching := false
|
||||||
searchQuery := ""
|
searchQuery := ""
|
||||||
|
searchInputMode := false
|
||||||
// Helper function to check if a file has an allowed extension from config
|
// Helper function to check if a file has an allowed extension from config
|
||||||
hasAllowedExtension := func(filename string) bool {
|
hasAllowedExtension := func(filename string) bool {
|
||||||
if cfg.FilePickerExts == "" {
|
if cfg.FilePickerExts == "" {
|
||||||
@@ -1144,6 +1012,7 @@ func makeFilePicker() *tview.Flex {
|
|||||||
case tcell.KeyEsc:
|
case tcell.KeyEsc:
|
||||||
// Exit search, clear filter
|
// Exit search, clear filter
|
||||||
searching = false
|
searching = false
|
||||||
|
searchInputMode = false
|
||||||
searchQuery = ""
|
searchQuery = ""
|
||||||
refreshList(currentDisplayDir, "")
|
refreshList(currentDisplayDir, "")
|
||||||
return nil
|
return nil
|
||||||
@@ -1153,16 +1022,80 @@ func makeFilePicker() *tview.Flex {
|
|||||||
refreshList(currentDisplayDir, searchQuery)
|
refreshList(currentDisplayDir, searchQuery)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case tcell.KeyRune:
|
case tcell.KeyEnter:
|
||||||
r := event.Rune()
|
// Exit search input mode and let normal processing handle selection
|
||||||
if r != 0 {
|
searchInputMode = false
|
||||||
searchQuery += string(r)
|
// Get the currently highlighted item in the list
|
||||||
refreshList(currentDisplayDir, searchQuery)
|
itemIndex := listView.GetCurrentItem()
|
||||||
|
if itemIndex >= 0 && itemIndex < listView.GetItemCount() {
|
||||||
|
itemText, _ := listView.GetItemText(itemIndex)
|
||||||
|
// Check for the exit option first
|
||||||
|
if strings.HasPrefix(itemText, "Exit file picker") {
|
||||||
|
pages.RemovePage(filePickerPage)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Extract the actual filename/directory name by removing the type info
|
||||||
|
actualItemName := itemText
|
||||||
|
if bracketPos := strings.Index(itemText, " ["); bracketPos != -1 {
|
||||||
|
actualItemName = itemText[:bracketPos]
|
||||||
|
}
|
||||||
|
// Check if it's a directory (ends with /)
|
||||||
|
if strings.HasSuffix(actualItemName, "/") {
|
||||||
|
var targetDir string
|
||||||
|
if strings.HasPrefix(actualItemName, "../") {
|
||||||
|
// Parent directory
|
||||||
|
targetDir = path.Dir(currentDisplayDir)
|
||||||
|
if targetDir == currentDisplayDir && currentDisplayDir == "/" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Regular subdirectory
|
||||||
|
dirName := strings.TrimSuffix(actualItemName, "/")
|
||||||
|
targetDir = path.Join(currentDisplayDir, dirName)
|
||||||
|
}
|
||||||
|
// Navigate – clear search
|
||||||
|
if cfg.ImagePreview && imgPreview != nil {
|
||||||
|
imgPreview.SetImage(nil)
|
||||||
|
}
|
||||||
|
searching = false
|
||||||
|
searchInputMode = false
|
||||||
|
searchQuery = ""
|
||||||
|
refreshList(targetDir, "")
|
||||||
|
dirStack = append(dirStack, targetDir)
|
||||||
|
currentStackPos = len(dirStack) - 1
|
||||||
|
statusView.SetText("Current: " + targetDir)
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
// It's a file
|
||||||
|
filePath := path.Join(currentDisplayDir, actualItemName)
|
||||||
|
if info, err := os.Stat(filePath); err == nil && !info.IsDir() {
|
||||||
|
if isImageFile(actualItemName) {
|
||||||
|
SetImageAttachment(filePath)
|
||||||
|
statusView.SetText("Image attached: " + filePath + " (will be sent with next message)")
|
||||||
|
pages.RemovePage(filePickerPage)
|
||||||
|
} else {
|
||||||
|
textArea.SetText(filePath, true)
|
||||||
|
app.SetFocus(textArea)
|
||||||
|
pages.RemovePage(filePickerPage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
case tcell.KeyRune:
|
||||||
|
r := event.Rune()
|
||||||
|
if searchInputMode && r != 0 {
|
||||||
|
searchQuery += string(r)
|
||||||
|
refreshList(currentDisplayDir, searchQuery)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// If not in search input mode, pass through for navigation
|
||||||
|
return event
|
||||||
default:
|
default:
|
||||||
// Pass all other keys (arrows, Enter, etc.) to normal processing
|
// Exit search input mode but keep filter active for navigation
|
||||||
// This allows selecting items while still in search mode
|
searchInputMode = false
|
||||||
|
// Pass all other keys (arrows, etc.) to normal processing
|
||||||
return event
|
return event
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1190,6 +1123,7 @@ func makeFilePicker() *tview.Flex {
|
|||||||
if event.Rune() == '/' {
|
if event.Rune() == '/' {
|
||||||
// Enter search mode
|
// Enter search mode
|
||||||
searching = true
|
searching = true
|
||||||
|
searchInputMode = true
|
||||||
searchQuery = ""
|
searchQuery = ""
|
||||||
refreshList(currentDisplayDir, "")
|
refreshList(currentDisplayDir, "")
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
71
tools.go
71
tools.go
@@ -17,6 +17,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gf-lt/rag"
|
"gf-lt/rag"
|
||||||
|
|
||||||
"github.com/GrailFinder/searchagent/searcher"
|
"github.com/GrailFinder/searchagent/searcher"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,7 +173,6 @@ func init() {
|
|||||||
panic("failed to init seachagent; error: " + err.Error())
|
panic("failed to init seachagent; error: " + err.Error())
|
||||||
}
|
}
|
||||||
WebSearcher = sa
|
WebSearcher = sa
|
||||||
|
|
||||||
if err := rag.Init(cfg, logger, store); err != nil {
|
if err := rag.Init(cfg, logger, store); err != nil {
|
||||||
logger.Warn("failed to init rag; rag_search tool will not be available", "error", err)
|
logger.Warn("failed to init rag; rag_search tool will not be available", "error", err)
|
||||||
}
|
}
|
||||||
@@ -265,21 +265,18 @@ func ragsearch(args map[string]string) []byte {
|
|||||||
"limit_arg", limitS, "error", err)
|
"limit_arg", limitS, "error", err)
|
||||||
limit = 3
|
limit = 3
|
||||||
}
|
}
|
||||||
|
|
||||||
ragInstance := rag.GetInstance()
|
ragInstance := rag.GetInstance()
|
||||||
if ragInstance == nil {
|
if ragInstance == nil {
|
||||||
msg := "rag not initialized; rag_search tool is not available"
|
msg := "rag not initialized; rag_search tool is not available"
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := ragInstance.Search(query, limit)
|
results, err := ragInstance.Search(query, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "rag search failed; error: " + err.Error()
|
msg := "rag search failed; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(results)
|
data, err := json.Marshal(results)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to marshal rag search result; error: " + err.Error()
|
msg := "failed to marshal rag search result; error: " + err.Error()
|
||||||
@@ -419,7 +416,6 @@ func recallTopics(args map[string]string) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// File Manipulation Tools
|
// File Manipulation Tools
|
||||||
|
|
||||||
func fileCreate(args map[string]string) []byte {
|
func fileCreate(args map[string]string) []byte {
|
||||||
path, ok := args["path"]
|
path, ok := args["path"]
|
||||||
if !ok || path == "" {
|
if !ok || path == "" {
|
||||||
@@ -427,20 +423,16 @@ func fileCreate(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
path = resolvePath(path)
|
path = resolvePath(path)
|
||||||
|
|
||||||
content, ok := args["content"]
|
content, ok := args["content"]
|
||||||
if !ok {
|
if !ok {
|
||||||
content = ""
|
content = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeStringToFile(path, content); err != nil {
|
if err := writeStringToFile(path, content); err != nil {
|
||||||
msg := "failed to create file; error: " + err.Error()
|
msg := "failed to create file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := "file created successfully at " + path
|
msg := "file created successfully at " + path
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
@@ -452,16 +444,13 @@ func fileRead(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
path = resolvePath(path)
|
path = resolvePath(path)
|
||||||
|
|
||||||
content, err := readStringFromFile(path)
|
content, err := readStringFromFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to read file; error: " + err.Error()
|
msg := "failed to read file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"content": content,
|
"content": content,
|
||||||
"path": path,
|
"path": path,
|
||||||
@@ -472,7 +461,6 @@ func fileRead(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -525,15 +513,12 @@ func fileDelete(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
path = resolvePath(path)
|
path = resolvePath(path)
|
||||||
|
|
||||||
if err := removeFile(path); err != nil {
|
if err := removeFile(path); err != nil {
|
||||||
msg := "failed to delete file; error: " + err.Error()
|
msg := "failed to delete file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := "file deleted successfully at " + path
|
msg := "file deleted successfully at " + path
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
@@ -546,7 +531,6 @@ func fileMove(args map[string]string) []byte {
|
|||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
src = resolvePath(src)
|
src = resolvePath(src)
|
||||||
|
|
||||||
dst, ok := args["dst"]
|
dst, ok := args["dst"]
|
||||||
if !ok || dst == "" {
|
if !ok || dst == "" {
|
||||||
msg := "destination path not provided to file_move tool"
|
msg := "destination path not provided to file_move tool"
|
||||||
@@ -554,13 +538,11 @@ func fileMove(args map[string]string) []byte {
|
|||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
dst = resolvePath(dst)
|
dst = resolvePath(dst)
|
||||||
|
|
||||||
if err := moveFile(src, dst); err != nil {
|
if err := moveFile(src, dst); err != nil {
|
||||||
msg := "failed to move file; error: " + err.Error()
|
msg := "failed to move file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := fmt.Sprintf("file moved successfully from %s to %s", src, dst)
|
msg := fmt.Sprintf("file moved successfully from %s to %s", src, dst)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
@@ -573,7 +555,6 @@ func fileCopy(args map[string]string) []byte {
|
|||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
src = resolvePath(src)
|
src = resolvePath(src)
|
||||||
|
|
||||||
dst, ok := args["dst"]
|
dst, ok := args["dst"]
|
||||||
if !ok || dst == "" {
|
if !ok || dst == "" {
|
||||||
msg := "destination path not provided to file_copy tool"
|
msg := "destination path not provided to file_copy tool"
|
||||||
@@ -581,13 +562,11 @@ func fileCopy(args map[string]string) []byte {
|
|||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
dst = resolvePath(dst)
|
dst = resolvePath(dst)
|
||||||
|
|
||||||
if err := copyFile(src, dst); err != nil {
|
if err := copyFile(src, dst); err != nil {
|
||||||
msg := "failed to copy file; error: " + err.Error()
|
msg := "failed to copy file; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := fmt.Sprintf("file copied successfully from %s to %s", src, dst)
|
msg := fmt.Sprintf("file copied successfully from %s to %s", src, dst)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
@@ -597,16 +576,13 @@ func fileList(args map[string]string) []byte {
|
|||||||
if !ok || path == "" {
|
if !ok || path == "" {
|
||||||
path = "." // default to current directory
|
path = "." // default to current directory
|
||||||
}
|
}
|
||||||
|
|
||||||
path = resolvePath(path)
|
path = resolvePath(path)
|
||||||
|
|
||||||
files, err := listDirectory(path)
|
files, err := listDirectory(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to list directory; error: " + err.Error()
|
msg := "failed to list directory; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
result := map[string]interface{}{
|
result := map[string]interface{}{
|
||||||
"directory": path,
|
"directory": path,
|
||||||
"files": files,
|
"files": files,
|
||||||
@@ -617,12 +593,10 @@ func fileList(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions for file operations
|
// Helper functions for file operations
|
||||||
|
|
||||||
func resolvePath(p string) string {
|
func resolvePath(p string) string {
|
||||||
if filepath.IsAbs(p) {
|
if filepath.IsAbs(p) {
|
||||||
return p
|
return p
|
||||||
@@ -648,7 +622,6 @@ func appendStringToFile(filename string, data string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
_, err = file.WriteString(data)
|
_, err = file.WriteString(data)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -672,13 +645,11 @@ func copyFile(src, dst string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer srcFile.Close()
|
defer srcFile.Close()
|
||||||
|
|
||||||
dstFile, err := os.Create(dst)
|
dstFile, err := os.Create(dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer dstFile.Close()
|
defer dstFile.Close()
|
||||||
|
|
||||||
_, err = io.Copy(dstFile, srcFile)
|
_, err = io.Copy(dstFile, srcFile)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -697,7 +668,6 @@ func listDirectory(path string) ([]string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var files []string
|
var files []string
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if entry.IsDir() {
|
if entry.IsDir() {
|
||||||
@@ -706,12 +676,10 @@ func listDirectory(path string) ([]string, error) {
|
|||||||
files = append(files, entry.Name())
|
files = append(files, entry.Name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return files, nil
|
return files, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Command Execution Tool
|
// Command Execution Tool
|
||||||
|
|
||||||
func executeCommand(args map[string]string) []byte {
|
func executeCommand(args map[string]string) []byte {
|
||||||
command, ok := args["command"]
|
command, ok := args["command"]
|
||||||
if !ok || command == "" {
|
if !ok || command == "" {
|
||||||
@@ -719,7 +687,6 @@ func executeCommand(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get arguments - handle both single arg and multiple args
|
// Get arguments - handle both single arg and multiple args
|
||||||
var cmdArgs []string
|
var cmdArgs []string
|
||||||
if args["args"] != "" {
|
if args["args"] != "" {
|
||||||
@@ -738,43 +705,36 @@ func executeCommand(args map[string]string) []byte {
|
|||||||
argNum++
|
argNum++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isCommandAllowed(command, cmdArgs...) {
|
if !isCommandAllowed(command, cmdArgs...) {
|
||||||
msg := fmt.Sprintf("command '%s' is not allowed", command)
|
msg := fmt.Sprintf("command '%s' is not allowed", command)
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute with timeout for safety
|
// Execute with timeout for safety
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
cmd := exec.CommandContext(ctx, command, cmdArgs...)
|
cmd := exec.CommandContext(ctx, command, cmdArgs...)
|
||||||
|
|
||||||
output, err := cmd.CombinedOutput()
|
output, err := cmd.CombinedOutput()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := fmt.Sprintf("command '%s' failed; error: %v; output: %s", command, err, string(output))
|
msg := fmt.Sprintf("command '%s' failed; error: %v; output: %s", command, err, string(output))
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if output is empty and return success message
|
// Check if output is empty and return success message
|
||||||
if len(output) == 0 {
|
if len(output) == 0 {
|
||||||
successMsg := fmt.Sprintf("command '%s %s' executed successfully and exited with code 0", command, strings.Join(cmdArgs, " "))
|
successMsg := fmt.Sprintf("command '%s %s' executed successfully and exited with code 0", command, strings.Join(cmdArgs, " "))
|
||||||
return []byte(successMsg)
|
return []byte(successMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return output
|
return output
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions for command execution
|
// Helper functions for command execution
|
||||||
|
|
||||||
// Todo structure
|
// Todo structure
|
||||||
type TodoItem struct {
|
type TodoItem struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Task string `json:"task"`
|
Task string `json:"task"`
|
||||||
Status string `json:"status"` // "pending", "in_progress", "completed"
|
Status string `json:"status"` // "pending", "in_progress", "completed"
|
||||||
}
|
}
|
||||||
|
|
||||||
type TodoList struct {
|
type TodoList struct {
|
||||||
Items []TodoItem `json:"items"`
|
Items []TodoItem `json:"items"`
|
||||||
}
|
}
|
||||||
@@ -792,32 +752,26 @@ func todoCreate(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate simple ID
|
// Generate simple ID
|
||||||
id := fmt.Sprintf("todo_%d", len(globalTodoList.Items)+1)
|
id := fmt.Sprintf("todo_%d", len(globalTodoList.Items)+1)
|
||||||
|
|
||||||
newItem := TodoItem{
|
newItem := TodoItem{
|
||||||
ID: id,
|
ID: id,
|
||||||
Task: task,
|
Task: task,
|
||||||
Status: "pending",
|
Status: "pending",
|
||||||
}
|
}
|
||||||
|
|
||||||
globalTodoList.Items = append(globalTodoList.Items, newItem)
|
globalTodoList.Items = append(globalTodoList.Items, newItem)
|
||||||
|
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"message": "todo created successfully",
|
"message": "todo created successfully",
|
||||||
"id": id,
|
"id": id,
|
||||||
"task": task,
|
"task": task,
|
||||||
"status": "pending",
|
"status": "pending",
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, err := json.Marshal(result)
|
jsonResult, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to marshal result; error: " + err.Error()
|
msg := "failed to marshal result; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -851,7 +805,6 @@ func todoRead(args map[string]string) []byte {
|
|||||||
}
|
}
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return all todos if no ID specified
|
// Return all todos if no ID specified
|
||||||
result := map[string]interface{}{
|
result := map[string]interface{}{
|
||||||
"todos": globalTodoList.Items,
|
"todos": globalTodoList.Items,
|
||||||
@@ -862,7 +815,6 @@ func todoRead(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -873,16 +825,13 @@ func todoUpdate(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
task, taskOk := args["task"]
|
task, taskOk := args["task"]
|
||||||
status, statusOk := args["status"]
|
status, statusOk := args["status"]
|
||||||
|
|
||||||
if !taskOk && !statusOk {
|
if !taskOk && !statusOk {
|
||||||
msg := "neither task nor status provided to todo_update tool"
|
msg := "neither task nor status provided to todo_update tool"
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find and update the todo
|
// Find and update the todo
|
||||||
for i, item := range globalTodoList.Items {
|
for i, item := range globalTodoList.Items {
|
||||||
if item.ID == id {
|
if item.ID == id {
|
||||||
@@ -906,23 +855,19 @@ func todoUpdate(args map[string]string) []byte {
|
|||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"message": "todo updated successfully",
|
"message": "todo updated successfully",
|
||||||
"id": id,
|
"id": id,
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, err := json.Marshal(result)
|
jsonResult, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to marshal result; error: " + err.Error()
|
msg := "failed to marshal result; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID not found
|
// ID not found
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"error": fmt.Sprintf("todo with id %s not found", id),
|
"error": fmt.Sprintf("todo with id %s not found", id),
|
||||||
@@ -943,29 +888,24 @@ func todoDelete(args map[string]string) []byte {
|
|||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find and remove the todo
|
// Find and remove the todo
|
||||||
for i, item := range globalTodoList.Items {
|
for i, item := range globalTodoList.Items {
|
||||||
if item.ID == id {
|
if item.ID == id {
|
||||||
// Remove item from slice
|
// Remove item from slice
|
||||||
globalTodoList.Items = append(globalTodoList.Items[:i], globalTodoList.Items[i+1:]...)
|
globalTodoList.Items = append(globalTodoList.Items[:i], globalTodoList.Items[i+1:]...)
|
||||||
|
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"message": "todo deleted successfully",
|
"message": "todo deleted successfully",
|
||||||
"id": id,
|
"id": id,
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, err := json.Marshal(result)
|
jsonResult, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "failed to marshal result; error: " + err.Error()
|
msg := "failed to marshal result; error: " + err.Error()
|
||||||
logger.Error(msg)
|
logger.Error(msg)
|
||||||
return []byte(msg)
|
return []byte(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonResult
|
return jsonResult
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID not found
|
// ID not found
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
"error": fmt.Sprintf("todo with id %s not found", id),
|
"error": fmt.Sprintf("todo with id %s not found", id),
|
||||||
@@ -1239,7 +1179,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_create
|
// file_create
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1262,7 +1201,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_read
|
// file_read
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1281,7 +1219,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_write
|
// file_write
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1304,7 +1241,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_write_append
|
// file_write_append
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1327,7 +1263,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_delete
|
// file_delete
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1346,7 +1281,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_move
|
// file_move
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1369,7 +1303,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_copy
|
// file_copy
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1392,7 +1325,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// file_list
|
// file_list
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
@@ -1411,7 +1343,6 @@ var baseTools = []models.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// execute_command
|
// execute_command
|
||||||
models.Tool{
|
models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
|
|||||||
20
tui.go
20
tui.go
@@ -264,7 +264,7 @@ func init() {
|
|||||||
pages.RemovePage(editMsgPage)
|
pages.RemovePage(editMsgPage)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
chatBody.Messages[selectedIndex].Content = editedMsg
|
chatBody.Messages[selectedIndex].SetText(editedMsg)
|
||||||
// change textarea
|
// change textarea
|
||||||
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys))
|
textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys))
|
||||||
pages.RemovePage(editMsgPage)
|
pages.RemovePage(editMsgPage)
|
||||||
@@ -352,13 +352,14 @@ func init() {
|
|||||||
case editMode:
|
case editMode:
|
||||||
hideIndexBar() // Hide overlay first
|
hideIndexBar() // Hide overlay first
|
||||||
pages.AddPage(editMsgPage, editArea, true, true)
|
pages.AddPage(editMsgPage, editArea, true, true)
|
||||||
editArea.SetText(m.Content, true)
|
editArea.SetText(m.GetText(), true)
|
||||||
default:
|
default:
|
||||||
if err := copyToClipboard(m.Content); err != nil {
|
msgText := m.GetText()
|
||||||
|
if err := copyToClipboard(msgText); err != nil {
|
||||||
logger.Error("failed to copy to clipboard", "error", err)
|
logger.Error("failed to copy to clipboard", "error", err)
|
||||||
}
|
}
|
||||||
previewLen := min(30, len(m.Content))
|
previewLen := min(30, len(msgText))
|
||||||
notification := fmt.Sprintf("msg '%s' was copied to the clipboard", m.Content[:previewLen])
|
notification := fmt.Sprintf("msg '%s' was copied to the clipboard", msgText[:previewLen])
|
||||||
if err := notifyUser("copied", notification); err != nil {
|
if err := notifyUser("copied", notification); err != nil {
|
||||||
logger.Error("failed to send notification", "error", err)
|
logger.Error("failed to send notification", "error", err)
|
||||||
}
|
}
|
||||||
@@ -648,11 +649,12 @@ func init() {
|
|||||||
// copy msg to clipboard
|
// copy msg to clipboard
|
||||||
editMode = false
|
editMode = false
|
||||||
m := chatBody.Messages[len(chatBody.Messages)-1]
|
m := chatBody.Messages[len(chatBody.Messages)-1]
|
||||||
if err := copyToClipboard(m.Content); err != nil {
|
msgText := m.GetText()
|
||||||
|
if err := copyToClipboard(msgText); err != nil {
|
||||||
logger.Error("failed to copy to clipboard", "error", err)
|
logger.Error("failed to copy to clipboard", "error", err)
|
||||||
}
|
}
|
||||||
previewLen := min(30, len(m.Content))
|
previewLen := min(30, len(msgText))
|
||||||
notification := fmt.Sprintf("msg '%s' was copied to the clipboard", m.Content[:previewLen])
|
notification := fmt.Sprintf("msg '%s' was copied to the clipboard", msgText[:previewLen])
|
||||||
if err := notifyUser("copied", notification); err != nil {
|
if err := notifyUser("copied", notification); err != nil {
|
||||||
logger.Error("failed to send notification", "error", err)
|
logger.Error("failed to send notification", "error", err)
|
||||||
}
|
}
|
||||||
@@ -847,7 +849,7 @@ func init() {
|
|||||||
// Stop any currently playing TTS first
|
// Stop any currently playing TTS first
|
||||||
TTSDoneChan <- true
|
TTSDoneChan <- true
|
||||||
lastMsg := chatBody.Messages[len(chatBody.Messages)-1]
|
lastMsg := chatBody.Messages[len(chatBody.Messages)-1]
|
||||||
cleanedText := models.CleanText(lastMsg.Content)
|
cleanedText := models.CleanText(lastMsg.GetText())
|
||||||
if cleanedText != "" {
|
if cleanedText != "" {
|
||||||
// nolint: errcheck
|
// nolint: errcheck
|
||||||
go orator.Speak(cleanedText)
|
go orator.Speak(cleanedText)
|
||||||
|
|||||||
Reference in New Issue
Block a user