diff --git a/llm.go b/llm.go index 6697dfa..ebda29b 100644 --- a/llm.go +++ b/llm.go @@ -11,7 +11,6 @@ import ( var imageAttachmentPath string // Global variable to track image attachment for next message var lastImg string // for ctrl+j -var RAGMsg = "Retrieved context for user's query:\n" // containsToolSysMsg checks if the toolSysMsg already exists in the chat body func containsToolSysMsg() bool { @@ -142,22 +141,6 @@ func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, erro newMsg = *processMessageTag(&newMsg) chatBody.Messages = append(chatBody.Messages, newMsg) } - // if rag - add as system message to avoid conflicts with tool usage - if !resume && cfg.RAGEnabled { - um := chatBody.Messages[len(chatBody.Messages)-1].Content - logger.Debug("RAG is enabled, preparing RAG context", "user_message", um) - ragResp, err := chatRagUse(um) - if err != nil { - logger.Error("failed to form a rag msg", "error", err) - return nil, err - } - logger.Debug("RAG response received", "response_len", len(ragResp), - "response_preview", ragResp[:min(len(ragResp), 100)]) - // Use system role for RAG context to avoid conflicts with tool usage - ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} - chatBody.Messages = append(chatBody.Messages, ragMsg) - logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) - } // sending description of the tools and how to use them if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() { chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) @@ -301,23 +284,6 @@ func (op LCPChat) FormMsg(msg, role string, resume bool) (io.Reader, error) { logger.Debug("LCPChat FormMsg: added message to chatBody", "role", newMsg.Role, "content_len", len(newMsg.Content), "message_count_after_add", len(chatBody.Messages)) } - // if rag - add as system message to avoid conflicts with tool usage - if !resume && cfg.RAGEnabled { - um := chatBody.Messages[len(chatBody.Messages)-1].Content - logger.Debug("LCPChat: RAG is enabled, preparing RAG context", "user_message", um) - ragResp, err := chatRagUse(um) - if err != nil { - logger.Error("LCPChat: failed to form a rag msg", "error", err) - return nil, err - } - logger.Debug("LCPChat: RAG response received", - "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)]) - // Use system role for RAG context to avoid conflicts with tool usage - ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} - chatBody.Messages = append(chatBody.Messages, ragMsg) - logger.Debug("LCPChat: RAG message added to chat body", "role", ragMsg.Role, - "rag_content_len", len(ragMsg.Content), "message_count_after_rag", len(chatBody.Messages)) - } filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages) // openai /v1/chat does not support custom roles; needs to be user, assistant, system // Add persona suffix to the last user message to indicate who the assistant should reply as @@ -389,22 +355,6 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader newMsg = *processMessageTag(&newMsg) chatBody.Messages = append(chatBody.Messages, newMsg) } - // if rag - add as system message to avoid conflicts with tool usage - if !resume && cfg.RAGEnabled { - um := chatBody.Messages[len(chatBody.Messages)-1].Content - logger.Debug("DeepSeekerCompletion: RAG is enabled, preparing RAG context", "user_message", um) - ragResp, err := chatRagUse(um) - if err != nil { - logger.Error("DeepSeekerCompletion: failed to form a rag msg", "error", err) - return nil, err - } - logger.Debug("DeepSeekerCompletion: RAG response received", - "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)]) - // Use system role for RAG context to avoid conflicts with tool usage - ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} - chatBody.Messages = append(chatBody.Messages, ragMsg) - logger.Debug("DeepSeekerCompletion: RAG message added to chat body", "message_count", len(chatBody.Messages)) - } // sending description of the tools and how to use them if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() { chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) @@ -474,22 +424,6 @@ func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, erro newMsg = *processMessageTag(&newMsg) chatBody.Messages = append(chatBody.Messages, newMsg) } - // if rag - add as system message to avoid conflicts with tool usage - if !resume && cfg.RAGEnabled { - um := chatBody.Messages[len(chatBody.Messages)-1].Content - logger.Debug("RAG is enabled, preparing RAG context", "user_message", um) - ragResp, err := chatRagUse(um) - if err != nil { - logger.Error("failed to form a rag msg", "error", err) - return nil, err - } - logger.Debug("RAG response received", "response_len", len(ragResp), - "response_preview", ragResp[:min(len(ragResp), 100)]) - // Use system role for RAG context to avoid conflicts with tool usage - ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} - chatBody.Messages = append(chatBody.Messages, ragMsg) - logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) - } // Create copy of chat body with standardized user role filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages) // Add persona suffix to the last user message to indicate who the assistant should reply as @@ -552,22 +486,6 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader newMsg = *processMessageTag(&newMsg) chatBody.Messages = append(chatBody.Messages, newMsg) } - // if rag - add as system message to avoid conflicts with tool usage - if !resume && cfg.RAGEnabled { - um := chatBody.Messages[len(chatBody.Messages)-1].Content - logger.Debug("RAG is enabled, preparing RAG context", "user_message", um) - ragResp, err := chatRagUse(um) - if err != nil { - logger.Error("failed to form a rag msg", "error", err) - return nil, err - } - logger.Debug("RAG response received", "response_len", - len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)]) - // Use system role for RAG context to avoid conflicts with tool usage - ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} - chatBody.Messages = append(chatBody.Messages, ragMsg) - logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) - } // sending description of the tools and how to use them if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() { chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) @@ -670,22 +588,6 @@ func (or OpenRouterChat) FormMsg(msg, role string, resume bool) (io.Reader, erro newMsg = *processMessageTag(&newMsg) chatBody.Messages = append(chatBody.Messages, newMsg) } - // if rag - add as system message to avoid conflicts with tool usage - if !resume && cfg.RAGEnabled { - um := chatBody.Messages[len(chatBody.Messages)-1].Content - logger.Debug("RAG is enabled, preparing RAG context", "user_message", um) - ragResp, err := chatRagUse(um) - if err != nil { - logger.Error("failed to form a rag msg", "error", err) - return nil, err - } - logger.Debug("RAG response received", "response_len", len(ragResp), - "response_preview", ragResp[:min(len(ragResp), 100)]) - // Use system role for RAG context to avoid conflicts with tool usage - ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} - chatBody.Messages = append(chatBody.Messages, ragMsg) - logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) - } // Create copy of chat body with standardized user role filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages) // Add persona suffix to the last user message to indicate who the assistant should reply as diff --git a/rag/rag.go b/rag/rag.go index f554924..b49bd97 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -9,6 +9,8 @@ import ( "log/slog" "os" "path" + "regexp" + "sort" "strings" "sync" @@ -195,3 +197,309 @@ func (r *RAG) ListLoaded() ([]string, error) { func (r *RAG) RemoveFile(filename string) error { return r.storage.RemoveEmbByFileName(filename) } + +var ( + queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`) + importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"} + stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"} +) + +func (r *RAG) RefineQuery(query string) string { + original := query + query = strings.TrimSpace(query) + + if len(query) == 0 { + return original + } + + if len(query) <= 3 { + return original + } + + query = strings.ToLower(query) + + for _, stopWord := range stopWords { + wordPattern := `\b` + stopWord + `\b` + re := regexp.MustCompile(wordPattern) + query = re.ReplaceAllString(query, "") + } + + query = strings.TrimSpace(query) + + if len(query) < 5 { + return original + } + + if queryRefinementPattern.MatchString(original) { + cleaned := queryRefinementPattern.ReplaceAllString(original, "") + cleaned = strings.TrimSpace(cleaned) + if len(cleaned) >= 5 { + return cleaned + } + } + + query = r.extractImportantPhrases(query) + + if len(query) < 5 { + return original + } + + return query +} + +func (r *RAG) extractImportantPhrases(query string) string { + words := strings.Fields(query) + + var important []string + for _, word := range words { + word = strings.Trim(word, ".,!?;:'\"()[]{}") + + isImportant := false + for _, kw := range importantKeywords { + if strings.Contains(strings.ToLower(word), kw) { + isImportant = true + break + } + } + + if isImportant || len(word) > 3 { + important = append(important, word) + } + } + + if len(important) == 0 { + return query + } + + return strings.Join(important, " ") +} + +func (r *RAG) GenerateQueryVariations(query string) []string { + variations := []string{query} + + if len(query) < 5 { + return variations + } + + parts := strings.Fields(query) + if len(parts) == 0 { + return variations + } + + if len(parts) >= 2 { + trimmed := strings.Join(parts[:len(parts)-1], " ") + if len(trimmed) >= 5 { + variations = append(variations, trimmed) + } + } + + if len(parts) >= 2 { + trimmed := strings.Join(parts[1:], " ") + if len(trimmed) >= 5 { + variations = append(variations, trimmed) + } + } + + if !strings.HasSuffix(query, " explanation") { + variations = append(variations, query+" explanation") + } + if !strings.HasPrefix(query, "what is ") { + variations = append(variations, "what is "+query) + } + if !strings.HasSuffix(query, " details") { + variations = append(variations, query+" details") + } + if !strings.HasSuffix(query, " summary") { + variations = append(variations, query+" summary") + } + + return variations +} + +func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow { + type scoredResult struct { + row models.VectorRow + distance float32 + } + + scored := make([]scoredResult, 0, len(results)) + + for i := range results { + row := results[i] + + score := float32(0) + + rawTextLower := strings.ToLower(row.RawText) + queryLower := strings.ToLower(query) + + if strings.Contains(rawTextLower, queryLower) { + score += 10 + } + + queryWords := strings.Fields(queryLower) + matchCount := 0 + for _, word := range queryWords { + if len(word) > 2 && strings.Contains(rawTextLower, word) { + matchCount++ + } + } + if len(queryWords) > 0 { + score += float32(matchCount) / float32(len(queryWords)) * 5 + } + + if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") { + score += 3 + } + + distance := row.Distance - score/100 + + scored = append(scored, scoredResult{row: row, distance: distance}) + } + + sort.Slice(scored, func(i, j int) bool { + return scored[i].distance < scored[j].distance + }) + + unique := make([]models.VectorRow, 0) + seen := make(map[string]bool) + + for i := range scored { + if !seen[scored[i].row.Slug] { + seen[scored[i].row.Slug] = true + unique = append(unique, scored[i].row) + } + } + + if len(unique) > 10 { + unique = unique[:10] + } + + return unique +} + +func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) { + if len(results) == 0 { + return "No relevant information found in the vector database.", nil + } + + var contextBuilder strings.Builder + contextBuilder.WriteString("User Query: ") + contextBuilder.WriteString(query) + contextBuilder.WriteString("\n\nRetrieved Context:\n") + + for i, row := range results { + contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName)) + contextBuilder.WriteString(row.RawText) + contextBuilder.WriteString("\n\n") + } + + contextBuilder.WriteString("Instructions: ") + contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ") + contextBuilder.WriteString("Extract only the most relevant information. ") + contextBuilder.WriteString("If no relevant information is found, state that clearly. ") + contextBuilder.WriteString("Cite sources by filename when relevant. ") + contextBuilder.WriteString("Do not include unnecessary preamble or explanations.") + + synthesisPrompt := contextBuilder.String() + + emb, err := r.LineToVector(synthesisPrompt) + if err != nil { + r.logger.Error("failed to embed synthesis prompt", "error", err) + return "", err + } + + embResp := &models.EmbeddingResp{ + Embedding: emb, + Index: 0, + } + + topResults, err := r.SearchEmb(embResp) + if err != nil { + r.logger.Error("failed to search for synthesis context", "error", err) + return "", err + } + + if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt { + return topResults[0].RawText, nil + } + + var finalAnswer strings.Builder + finalAnswer.WriteString("Based on the retrieved context:\n\n") + + for i, row := range results { + if i >= 5 { + break + } + finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))) + } + + return finalAnswer.String(), nil +} + +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { + refined := r.RefineQuery(query) + variations := r.GenerateQueryVariations(refined) + + allResults := make([]models.VectorRow, 0) + seen := make(map[string]bool) + + for _, q := range variations { + emb, err := r.LineToVector(q) + if err != nil { + r.logger.Error("failed to embed query variation", "error", err, "query", q) + continue + } + + embResp := &models.EmbeddingResp{ + Embedding: emb, + Index: 0, + } + + results, err := r.SearchEmb(embResp) + if err != nil { + r.logger.Error("failed to search embeddings", "error", err, "query", q) + continue + } + + for _, row := range results { + if !seen[row.Slug] { + seen[row.Slug] = true + allResults = append(allResults, row) + } + } + } + + reranked := r.RerankResults(allResults, query) + + if len(reranked) > limit { + reranked = reranked[:limit] + } + + return reranked, nil +} + +var ( + ragInstance *RAG + ragOnce sync.Once +) + +func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error { + ragOnce.Do(func() { + if c == nil || l == nil || s == nil { + return + } + ragInstance = New(l, s, c) + }) + return nil +} + +func GetInstance() *RAG { + return ragInstance +} diff --git a/tools.go b/tools.go index c397137..5cd2770 100644 --- a/tools.go +++ b/tools.go @@ -16,6 +16,7 @@ import ( "sync" "time" + "gf-lt/rag" "github.com/GrailFinder/searchagent/searcher" ) @@ -58,9 +59,9 @@ Your current tools: "when_to_use": "when asked to search the web for information; returns clean summary without html,css and other web elements; limit is optional (default 3)" }, { -"name":"websearch_raw", +"name":"rag_search", "args": ["query", "limit"], -"when_to_use": "when asked to search the web for information; returns raw data as is without processing; limit is optional (default 3)" +"when_to_use": "when asked to search the local document database for information; performs query refinement, semantic search, reranking, and synthesis; returns clean summary with sources; limit is optional (default 3)" }, { "name":"read_url", @@ -146,6 +147,7 @@ under the topic: Adam's number is stored: After that you are free to respond to the user. ` webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.` + ragSearchSysPrompt = `Synthesize the document search results, extracting key information and presenting a concise answer. Provide sources and document IDs where relevant.` readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.` summarySysPrompt = `Please provide a concise summary of the following conversation. Focus on key points, decisions, and actions. Provide only the summary, no additional commentary.` basicCard = &models.CharCard{ @@ -170,6 +172,10 @@ func init() { panic("failed to init seachagent; error: " + err.Error()) } WebSearcher = sa + + if err := rag.Init(cfg, logger, store); err != nil { + logger.Warn("failed to init rag; rag_search tool will not be available", "error", err) + } } // getWebAgentClient returns a singleton AgentClient for web agents. @@ -196,6 +202,8 @@ func getWebAgentClient() *agent.AgentClient { func registerWebAgents() { webAgentsOnce.Do(func() { client := getWebAgentClient() + // Register rag_search agent + agent.Register("rag_search", agent.NewWebAgentB(client, ragSearchSysPrompt)) // Register websearch agent agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt)) // Register read_url agent @@ -239,6 +247,48 @@ func websearch(args map[string]string) []byte { return data } +// rag search (searches local document database) +func ragsearch(args map[string]string) []byte { + query, ok := args["query"] + if !ok || query == "" { + msg := "query not provided to rag_search tool" + logger.Error(msg) + return []byte(msg) + } + limitS, ok := args["limit"] + if !ok || limitS == "" { + limitS = "3" + } + limit, err := strconv.Atoi(limitS) + if err != nil || limit == 0 { + logger.Warn("ragsearch limit; passed bad value; setting to default (3)", + "limit_arg", limitS, "error", err) + limit = 3 + } + + ragInstance := rag.GetInstance() + if ragInstance == nil { + msg := "rag not initialized; rag_search tool is not available" + logger.Error(msg) + return []byte(msg) + } + + results, err := ragInstance.Search(query, limit) + if err != nil { + msg := "rag search failed; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + data, err := json.Marshal(results) + if err != nil { + msg := "failed to marshal rag search result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + return data +} + // web search raw (returns raw data without processing) func websearchRaw(args map[string]string) []byte { // make http request return bytes @@ -997,6 +1047,7 @@ var fnMap = map[string]fnSig{ "recall": recall, "recall_topics": recallTopics, "memorise": memorise, + "rag_search": ragsearch, "websearch": websearch, "websearch_raw": websearchRaw, "read_url": readURL, @@ -1033,6 +1084,28 @@ func callToolWithAgent(name string, args map[string]string) []byte { // openai style def var baseTools = []models.Tool{ + // rag_search + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "rag_search", + Description: "Search local document database given query, limit of sources (default 3). Performs query refinement, semantic search, reranking, and synthesis.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"query", "limit"}, + Properties: map[string]models.ToolArgProps{ + "query": models.ToolArgProps{ + Type: "string", + Description: "search query", + }, + "limit": models.ToolArgProps{ + Type: "string", + Description: "limit of the document results", + }, + }, + }, + }, + }, // websearch models.Tool{ Type: "function",