Feat: rag tool

This commit is contained in:
Grail Finder
2026-02-24 20:24:44 +03:00
parent 27288e2aaa
commit 6c03a1a277
3 changed files with 383 additions and 100 deletions

98
llm.go
View File

@@ -11,7 +11,6 @@ import (
var imageAttachmentPath string // Global variable to track image attachment for next message var imageAttachmentPath string // Global variable to track image attachment for next message
var lastImg string // for ctrl+j var lastImg string // for ctrl+j
var RAGMsg = "Retrieved context for user's query:\n"
// containsToolSysMsg checks if the toolSysMsg already exists in the chat body // containsToolSysMsg checks if the toolSysMsg already exists in the chat body
func containsToolSysMsg() bool { func containsToolSysMsg() bool {
@@ -142,22 +141,6 @@ func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, erro
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.Messages = append(chatBody.Messages, newMsg)
} }
// if rag - add as system message to avoid conflicts with tool usage
if !resume && cfg.RAGEnabled {
um := chatBody.Messages[len(chatBody.Messages)-1].Content
logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
ragResp, err := chatRagUse(um)
if err != nil {
logger.Error("failed to form a rag msg", "error", err)
return nil, err
}
logger.Debug("RAG response received", "response_len", len(ragResp),
"response_preview", ragResp[:min(len(ragResp), 100)])
// Use system role for RAG context to avoid conflicts with tool usage
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
chatBody.Messages = append(chatBody.Messages, ragMsg)
logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
}
// sending description of the tools and how to use them // sending description of the tools and how to use them
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() { if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
@@ -301,23 +284,6 @@ func (op LCPChat) FormMsg(msg, role string, resume bool) (io.Reader, error) {
logger.Debug("LCPChat FormMsg: added message to chatBody", "role", newMsg.Role, logger.Debug("LCPChat FormMsg: added message to chatBody", "role", newMsg.Role,
"content_len", len(newMsg.Content), "message_count_after_add", len(chatBody.Messages)) "content_len", len(newMsg.Content), "message_count_after_add", len(chatBody.Messages))
} }
// if rag - add as system message to avoid conflicts with tool usage
if !resume && cfg.RAGEnabled {
um := chatBody.Messages[len(chatBody.Messages)-1].Content
logger.Debug("LCPChat: RAG is enabled, preparing RAG context", "user_message", um)
ragResp, err := chatRagUse(um)
if err != nil {
logger.Error("LCPChat: failed to form a rag msg", "error", err)
return nil, err
}
logger.Debug("LCPChat: RAG response received",
"response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)])
// Use system role for RAG context to avoid conflicts with tool usage
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
chatBody.Messages = append(chatBody.Messages, ragMsg)
logger.Debug("LCPChat: RAG message added to chat body", "role", ragMsg.Role,
"rag_content_len", len(ragMsg.Content), "message_count_after_rag", len(chatBody.Messages))
}
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages) filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
// openai /v1/chat does not support custom roles; needs to be user, assistant, system // openai /v1/chat does not support custom roles; needs to be user, assistant, system
// Add persona suffix to the last user message to indicate who the assistant should reply as // Add persona suffix to the last user message to indicate who the assistant should reply as
@@ -389,22 +355,6 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.Messages = append(chatBody.Messages, newMsg)
} }
// if rag - add as system message to avoid conflicts with tool usage
if !resume && cfg.RAGEnabled {
um := chatBody.Messages[len(chatBody.Messages)-1].Content
logger.Debug("DeepSeekerCompletion: RAG is enabled, preparing RAG context", "user_message", um)
ragResp, err := chatRagUse(um)
if err != nil {
logger.Error("DeepSeekerCompletion: failed to form a rag msg", "error", err)
return nil, err
}
logger.Debug("DeepSeekerCompletion: RAG response received",
"response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)])
// Use system role for RAG context to avoid conflicts with tool usage
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
chatBody.Messages = append(chatBody.Messages, ragMsg)
logger.Debug("DeepSeekerCompletion: RAG message added to chat body", "message_count", len(chatBody.Messages))
}
// sending description of the tools and how to use them // sending description of the tools and how to use them
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() { if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
@@ -474,22 +424,6 @@ func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.Messages = append(chatBody.Messages, newMsg)
} }
// if rag - add as system message to avoid conflicts with tool usage
if !resume && cfg.RAGEnabled {
um := chatBody.Messages[len(chatBody.Messages)-1].Content
logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
ragResp, err := chatRagUse(um)
if err != nil {
logger.Error("failed to form a rag msg", "error", err)
return nil, err
}
logger.Debug("RAG response received", "response_len", len(ragResp),
"response_preview", ragResp[:min(len(ragResp), 100)])
// Use system role for RAG context to avoid conflicts with tool usage
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
chatBody.Messages = append(chatBody.Messages, ragMsg)
logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
}
// Create copy of chat body with standardized user role // Create copy of chat body with standardized user role
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages) filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
// Add persona suffix to the last user message to indicate who the assistant should reply as // Add persona suffix to the last user message to indicate who the assistant should reply as
@@ -552,22 +486,6 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.Messages = append(chatBody.Messages, newMsg)
} }
// if rag - add as system message to avoid conflicts with tool usage
if !resume && cfg.RAGEnabled {
um := chatBody.Messages[len(chatBody.Messages)-1].Content
logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
ragResp, err := chatRagUse(um)
if err != nil {
logger.Error("failed to form a rag msg", "error", err)
return nil, err
}
logger.Debug("RAG response received", "response_len",
len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)])
// Use system role for RAG context to avoid conflicts with tool usage
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
chatBody.Messages = append(chatBody.Messages, ragMsg)
logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
}
// sending description of the tools and how to use them // sending description of the tools and how to use them
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() { if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
@@ -670,22 +588,6 @@ func (or OpenRouterChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
newMsg = *processMessageTag(&newMsg) newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg) chatBody.Messages = append(chatBody.Messages, newMsg)
} }
// if rag - add as system message to avoid conflicts with tool usage
if !resume && cfg.RAGEnabled {
um := chatBody.Messages[len(chatBody.Messages)-1].Content
logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
ragResp, err := chatRagUse(um)
if err != nil {
logger.Error("failed to form a rag msg", "error", err)
return nil, err
}
logger.Debug("RAG response received", "response_len", len(ragResp),
"response_preview", ragResp[:min(len(ragResp), 100)])
// Use system role for RAG context to avoid conflicts with tool usage
ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
chatBody.Messages = append(chatBody.Messages, ragMsg)
logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
}
// Create copy of chat body with standardized user role // Create copy of chat body with standardized user role
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages) filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
// Add persona suffix to the last user message to indicate who the assistant should reply as // Add persona suffix to the last user message to indicate who the assistant should reply as

View File

@@ -9,6 +9,8 @@ import (
"log/slog" "log/slog"
"os" "os"
"path" "path"
"regexp"
"sort"
"strings" "strings"
"sync" "sync"
@@ -195,3 +197,309 @@ func (r *RAG) ListLoaded() ([]string, error) {
func (r *RAG) RemoveFile(filename string) error { func (r *RAG) RemoveFile(filename string) error {
return r.storage.RemoveEmbByFileName(filename) return r.storage.RemoveEmbByFileName(filename)
} }
var (
queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"}
)
func (r *RAG) RefineQuery(query string) string {
original := query
query = strings.TrimSpace(query)
if len(query) == 0 {
return original
}
if len(query) <= 3 {
return original
}
query = strings.ToLower(query)
for _, stopWord := range stopWords {
wordPattern := `\b` + stopWord + `\b`
re := regexp.MustCompile(wordPattern)
query = re.ReplaceAllString(query, "")
}
query = strings.TrimSpace(query)
if len(query) < 5 {
return original
}
if queryRefinementPattern.MatchString(original) {
cleaned := queryRefinementPattern.ReplaceAllString(original, "")
cleaned = strings.TrimSpace(cleaned)
if len(cleaned) >= 5 {
return cleaned
}
}
query = r.extractImportantPhrases(query)
if len(query) < 5 {
return original
}
return query
}
func (r *RAG) extractImportantPhrases(query string) string {
words := strings.Fields(query)
var important []string
for _, word := range words {
word = strings.Trim(word, ".,!?;:'\"()[]{}")
isImportant := false
for _, kw := range importantKeywords {
if strings.Contains(strings.ToLower(word), kw) {
isImportant = true
break
}
}
if isImportant || len(word) > 3 {
important = append(important, word)
}
}
if len(important) == 0 {
return query
}
return strings.Join(important, " ")
}
func (r *RAG) GenerateQueryVariations(query string) []string {
variations := []string{query}
if len(query) < 5 {
return variations
}
parts := strings.Fields(query)
if len(parts) == 0 {
return variations
}
if len(parts) >= 2 {
trimmed := strings.Join(parts[:len(parts)-1], " ")
if len(trimmed) >= 5 {
variations = append(variations, trimmed)
}
}
if len(parts) >= 2 {
trimmed := strings.Join(parts[1:], " ")
if len(trimmed) >= 5 {
variations = append(variations, trimmed)
}
}
if !strings.HasSuffix(query, " explanation") {
variations = append(variations, query+" explanation")
}
if !strings.HasPrefix(query, "what is ") {
variations = append(variations, "what is "+query)
}
if !strings.HasSuffix(query, " details") {
variations = append(variations, query+" details")
}
if !strings.HasSuffix(query, " summary") {
variations = append(variations, query+" summary")
}
return variations
}
func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow {
type scoredResult struct {
row models.VectorRow
distance float32
}
scored := make([]scoredResult, 0, len(results))
for i := range results {
row := results[i]
score := float32(0)
rawTextLower := strings.ToLower(row.RawText)
queryLower := strings.ToLower(query)
if strings.Contains(rawTextLower, queryLower) {
score += 10
}
queryWords := strings.Fields(queryLower)
matchCount := 0
for _, word := range queryWords {
if len(word) > 2 && strings.Contains(rawTextLower, word) {
matchCount++
}
}
if len(queryWords) > 0 {
score += float32(matchCount) / float32(len(queryWords)) * 5
}
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
score += 3
}
distance := row.Distance - score/100
scored = append(scored, scoredResult{row: row, distance: distance})
}
sort.Slice(scored, func(i, j int) bool {
return scored[i].distance < scored[j].distance
})
unique := make([]models.VectorRow, 0)
seen := make(map[string]bool)
for i := range scored {
if !seen[scored[i].row.Slug] {
seen[scored[i].row.Slug] = true
unique = append(unique, scored[i].row)
}
}
if len(unique) > 10 {
unique = unique[:10]
}
return unique
}
func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
if len(results) == 0 {
return "No relevant information found in the vector database.", nil
}
var contextBuilder strings.Builder
contextBuilder.WriteString("User Query: ")
contextBuilder.WriteString(query)
contextBuilder.WriteString("\n\nRetrieved Context:\n")
for i, row := range results {
contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName))
contextBuilder.WriteString(row.RawText)
contextBuilder.WriteString("\n\n")
}
contextBuilder.WriteString("Instructions: ")
contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
contextBuilder.WriteString("Extract only the most relevant information. ")
contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
contextBuilder.WriteString("Cite sources by filename when relevant. ")
contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
synthesisPrompt := contextBuilder.String()
emb, err := r.LineToVector(synthesisPrompt)
if err != nil {
r.logger.Error("failed to embed synthesis prompt", "error", err)
return "", err
}
embResp := &models.EmbeddingResp{
Embedding: emb,
Index: 0,
}
topResults, err := r.SearchEmb(embResp)
if err != nil {
r.logger.Error("failed to search for synthesis context", "error", err)
return "", err
}
if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
return topResults[0].RawText, nil
}
var finalAnswer strings.Builder
finalAnswer.WriteString("Based on the retrieved context:\n\n")
for i, row := range results {
if i >= 5 {
break
}
finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200)))
}
return finalAnswer.String(), nil
}
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
refined := r.RefineQuery(query)
variations := r.GenerateQueryVariations(refined)
allResults := make([]models.VectorRow, 0)
seen := make(map[string]bool)
for _, q := range variations {
emb, err := r.LineToVector(q)
if err != nil {
r.logger.Error("failed to embed query variation", "error", err, "query", q)
continue
}
embResp := &models.EmbeddingResp{
Embedding: emb,
Index: 0,
}
results, err := r.SearchEmb(embResp)
if err != nil {
r.logger.Error("failed to search embeddings", "error", err, "query", q)
continue
}
for _, row := range results {
if !seen[row.Slug] {
seen[row.Slug] = true
allResults = append(allResults, row)
}
}
}
reranked := r.RerankResults(allResults, query)
if len(reranked) > limit {
reranked = reranked[:limit]
}
return reranked, nil
}
var (
ragInstance *RAG
ragOnce sync.Once
)
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
ragOnce.Do(func() {
if c == nil || l == nil || s == nil {
return
}
ragInstance = New(l, s, c)
})
return nil
}
func GetInstance() *RAG {
return ragInstance
}

View File

@@ -16,6 +16,7 @@ import (
"sync" "sync"
"time" "time"
"gf-lt/rag"
"github.com/GrailFinder/searchagent/searcher" "github.com/GrailFinder/searchagent/searcher"
) )
@@ -58,9 +59,9 @@ Your current tools:
"when_to_use": "when asked to search the web for information; returns clean summary without html,css and other web elements; limit is optional (default 3)" "when_to_use": "when asked to search the web for information; returns clean summary without html,css and other web elements; limit is optional (default 3)"
}, },
{ {
"name":"websearch_raw", "name":"rag_search",
"args": ["query", "limit"], "args": ["query", "limit"],
"when_to_use": "when asked to search the web for information; returns raw data as is without processing; limit is optional (default 3)" "when_to_use": "when asked to search the local document database for information; performs query refinement, semantic search, reranking, and synthesis; returns clean summary with sources; limit is optional (default 3)"
}, },
{ {
"name":"read_url", "name":"read_url",
@@ -146,6 +147,7 @@ under the topic: Adam's number is stored:
After that you are free to respond to the user. After that you are free to respond to the user.
` `
webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.` webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.`
ragSearchSysPrompt = `Synthesize the document search results, extracting key information and presenting a concise answer. Provide sources and document IDs where relevant.`
readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.` readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.`
summarySysPrompt = `Please provide a concise summary of the following conversation. Focus on key points, decisions, and actions. Provide only the summary, no additional commentary.` summarySysPrompt = `Please provide a concise summary of the following conversation. Focus on key points, decisions, and actions. Provide only the summary, no additional commentary.`
basicCard = &models.CharCard{ basicCard = &models.CharCard{
@@ -170,6 +172,10 @@ func init() {
panic("failed to init seachagent; error: " + err.Error()) panic("failed to init seachagent; error: " + err.Error())
} }
WebSearcher = sa WebSearcher = sa
if err := rag.Init(cfg, logger, store); err != nil {
logger.Warn("failed to init rag; rag_search tool will not be available", "error", err)
}
} }
// getWebAgentClient returns a singleton AgentClient for web agents. // getWebAgentClient returns a singleton AgentClient for web agents.
@@ -196,6 +202,8 @@ func getWebAgentClient() *agent.AgentClient {
func registerWebAgents() { func registerWebAgents() {
webAgentsOnce.Do(func() { webAgentsOnce.Do(func() {
client := getWebAgentClient() client := getWebAgentClient()
// Register rag_search agent
agent.Register("rag_search", agent.NewWebAgentB(client, ragSearchSysPrompt))
// Register websearch agent // Register websearch agent
agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt)) agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt))
// Register read_url agent // Register read_url agent
@@ -239,6 +247,48 @@ func websearch(args map[string]string) []byte {
return data return data
} }
// rag search (searches local document database)
func ragsearch(args map[string]string) []byte {
query, ok := args["query"]
if !ok || query == "" {
msg := "query not provided to rag_search tool"
logger.Error(msg)
return []byte(msg)
}
limitS, ok := args["limit"]
if !ok || limitS == "" {
limitS = "3"
}
limit, err := strconv.Atoi(limitS)
if err != nil || limit == 0 {
logger.Warn("ragsearch limit; passed bad value; setting to default (3)",
"limit_arg", limitS, "error", err)
limit = 3
}
ragInstance := rag.GetInstance()
if ragInstance == nil {
msg := "rag not initialized; rag_search tool is not available"
logger.Error(msg)
return []byte(msg)
}
results, err := ragInstance.Search(query, limit)
if err != nil {
msg := "rag search failed; error: " + err.Error()
logger.Error(msg)
return []byte(msg)
}
data, err := json.Marshal(results)
if err != nil {
msg := "failed to marshal rag search result; error: " + err.Error()
logger.Error(msg)
return []byte(msg)
}
return data
}
// web search raw (returns raw data without processing) // web search raw (returns raw data without processing)
func websearchRaw(args map[string]string) []byte { func websearchRaw(args map[string]string) []byte {
// make http request return bytes // make http request return bytes
@@ -997,6 +1047,7 @@ var fnMap = map[string]fnSig{
"recall": recall, "recall": recall,
"recall_topics": recallTopics, "recall_topics": recallTopics,
"memorise": memorise, "memorise": memorise,
"rag_search": ragsearch,
"websearch": websearch, "websearch": websearch,
"websearch_raw": websearchRaw, "websearch_raw": websearchRaw,
"read_url": readURL, "read_url": readURL,
@@ -1033,6 +1084,28 @@ func callToolWithAgent(name string, args map[string]string) []byte {
// openai style def // openai style def
var baseTools = []models.Tool{ var baseTools = []models.Tool{
// rag_search
models.Tool{
Type: "function",
Function: models.ToolFunc{
Name: "rag_search",
Description: "Search local document database given query, limit of sources (default 3). Performs query refinement, semantic search, reranking, and synthesis.",
Parameters: models.ToolFuncParams{
Type: "object",
Required: []string{"query", "limit"},
Properties: map[string]models.ToolArgProps{
"query": models.ToolArgProps{
Type: "string",
Description: "search query",
},
"limit": models.ToolArgProps{
Type: "string",
Description: "limit of the document results",
},
},
},
},
},
// websearch // websearch
models.Tool{ models.Tool{
Type: "function", Type: "function",