Chore: linter complaints
This commit is contained in:
@@ -140,7 +140,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
|||||||
ag.log.Error("failed to read request body", "error", err)
|
ag.log.Error("failed to read request body", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
|
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("failed to create request", "error", err)
|
ag.log.Error("failed to create request", "error", err)
|
||||||
@@ -150,22 +149,18 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
|||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Add("Authorization", "Bearer "+ag.getToken())
|
req.Header.Add("Authorization", "Bearer "+ag.getToken())
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
|
||||||
ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
|
ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
resp, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
|
ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
responseBytes, err := io.ReadAll(resp.Body)
|
responseBytes, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ag.log.Error("failed to read response", "error", err)
|
ag.log.Error("failed to read response", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
|
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
|
||||||
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
|
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
|
||||||
@@ -178,7 +173,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
|
|||||||
// Return raw response as fallback
|
// Return raw response as fallback
|
||||||
return responseBytes, nil
|
return responseBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return []byte(text), nil
|
return []byte(text), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ func startModelColorUpdater() {
|
|||||||
|
|
||||||
// Initial check
|
// Initial check
|
||||||
updateCachedModelColor()
|
updateCachedModelColor()
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
updateCachedModelColor()
|
updateCachedModelColor()
|
||||||
}
|
}
|
||||||
|
|||||||
3
llm.go
3
llm.go
@@ -216,13 +216,11 @@ func (op LCPChat) ParseChunk(data []byte) (*models.TextChunk, error) {
|
|||||||
logger.Warn("LCPChat ParseChunk: no choices in response", "data", string(data))
|
logger.Warn("LCPChat ParseChunk: no choices in response", "data", string(data))
|
||||||
return &models.TextChunk{Finished: true}, nil
|
return &models.TextChunk{Finished: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
lastChoice := llmchunk.Choices[len(llmchunk.Choices)-1]
|
lastChoice := llmchunk.Choices[len(llmchunk.Choices)-1]
|
||||||
resp := &models.TextChunk{
|
resp := &models.TextChunk{
|
||||||
Chunk: lastChoice.Delta.Content,
|
Chunk: lastChoice.Delta.Content,
|
||||||
Reasoning: lastChoice.Delta.ReasoningContent,
|
Reasoning: lastChoice.Delta.ReasoningContent,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for tool calls in all choices, not just the last one
|
// Check for tool calls in all choices, not just the last one
|
||||||
for _, choice := range llmchunk.Choices {
|
for _, choice := range llmchunk.Choices {
|
||||||
if len(choice.Delta.ToolCalls) > 0 {
|
if len(choice.Delta.ToolCalls) > 0 {
|
||||||
@@ -237,7 +235,6 @@ func (op LCPChat) ParseChunk(data []byte) (*models.TextChunk, error) {
|
|||||||
break // Process only the first tool call
|
break // Process only the first tool call
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastChoice.FinishReason == "stop" {
|
if lastChoice.FinishReason == "stop" {
|
||||||
if resp.Chunk != "" {
|
if resp.Chunk != "" {
|
||||||
logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
|
logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
|
||||||
|
|||||||
@@ -400,7 +400,6 @@ func (m *RoleMsg) AddTextPart(text string) {
|
|||||||
}
|
}
|
||||||
m.hasContentParts = true
|
m.hasContentParts = true
|
||||||
}
|
}
|
||||||
|
|
||||||
textPart := TextContentPart{Type: "text", Text: text}
|
textPart := TextContentPart{Type: "text", Text: text}
|
||||||
m.ContentParts = append(m.ContentParts, textPart)
|
m.ContentParts = append(m.ContentParts, textPart)
|
||||||
}
|
}
|
||||||
@@ -416,7 +415,6 @@ func (m *RoleMsg) AddImagePart(imageURL, imagePath string) {
|
|||||||
}
|
}
|
||||||
m.hasContentParts = true
|
m.hasContentParts = true
|
||||||
}
|
}
|
||||||
|
|
||||||
imagePart := ImageContentPart{
|
imagePart := ImageContentPart{
|
||||||
Type: "image_url",
|
Type: "image_url",
|
||||||
Path: imagePath, // Store the original file path
|
Path: imagePath, // Store the original file path
|
||||||
|
|||||||
@@ -410,38 +410,30 @@ func updateWidgetColors(theme *tview.Theme) {
|
|||||||
fgColor := theme.PrimaryTextColor
|
fgColor := theme.PrimaryTextColor
|
||||||
borderColor := theme.BorderColor
|
borderColor := theme.BorderColor
|
||||||
titleColor := theme.TitleColor
|
titleColor := theme.TitleColor
|
||||||
|
|
||||||
textView.SetBackgroundColor(bgColor)
|
textView.SetBackgroundColor(bgColor)
|
||||||
textView.SetTextColor(fgColor)
|
textView.SetTextColor(fgColor)
|
||||||
textView.SetBorderColor(borderColor)
|
textView.SetBorderColor(borderColor)
|
||||||
textView.SetTitleColor(titleColor)
|
textView.SetTitleColor(titleColor)
|
||||||
|
|
||||||
textArea.SetBackgroundColor(bgColor)
|
textArea.SetBackgroundColor(bgColor)
|
||||||
textArea.SetBorderColor(borderColor)
|
textArea.SetBorderColor(borderColor)
|
||||||
textArea.SetTitleColor(titleColor)
|
textArea.SetTitleColor(titleColor)
|
||||||
textArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
textArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
textArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
textArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
// Force textarea refresh by restoring text (SetTextStyle doesn't trigger redraw)
|
|
||||||
textArea.SetText(textArea.GetText(), true)
|
textArea.SetText(textArea.GetText(), true)
|
||||||
|
|
||||||
editArea.SetBackgroundColor(bgColor)
|
editArea.SetBackgroundColor(bgColor)
|
||||||
editArea.SetBorderColor(borderColor)
|
editArea.SetBorderColor(borderColor)
|
||||||
editArea.SetTitleColor(titleColor)
|
editArea.SetTitleColor(titleColor)
|
||||||
editArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
editArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
editArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
editArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
|
||||||
// Force textarea refresh by restoring text (SetTextStyle doesn't trigger redraw)
|
|
||||||
editArea.SetText(editArea.GetText(), true)
|
editArea.SetText(editArea.GetText(), true)
|
||||||
|
|
||||||
statusLineWidget.SetBackgroundColor(bgColor)
|
statusLineWidget.SetBackgroundColor(bgColor)
|
||||||
statusLineWidget.SetTextColor(fgColor)
|
statusLineWidget.SetTextColor(fgColor)
|
||||||
statusLineWidget.SetBorderColor(borderColor)
|
statusLineWidget.SetBorderColor(borderColor)
|
||||||
statusLineWidget.SetTitleColor(titleColor)
|
statusLineWidget.SetTitleColor(titleColor)
|
||||||
|
|
||||||
helpView.SetBackgroundColor(bgColor)
|
helpView.SetBackgroundColor(bgColor)
|
||||||
helpView.SetTextColor(fgColor)
|
helpView.SetTextColor(fgColor)
|
||||||
helpView.SetBorderColor(borderColor)
|
helpView.SetBorderColor(borderColor)
|
||||||
helpView.SetTitleColor(titleColor)
|
helpView.SetTitleColor(titleColor)
|
||||||
|
|
||||||
searchField.SetBackgroundColor(bgColor)
|
searchField.SetBackgroundColor(bgColor)
|
||||||
searchField.SetBorderColor(borderColor)
|
searchField.SetBorderColor(borderColor)
|
||||||
searchField.SetTitleColor(titleColor)
|
searchField.SetTitleColor(titleColor)
|
||||||
@@ -468,7 +460,6 @@ func showColorschemeSelectionPopup() {
|
|||||||
schemeListWidget := tview.NewList().ShowSecondaryText(false).
|
schemeListWidget := tview.NewList().ShowSecondaryText(false).
|
||||||
SetSelectedBackgroundColor(tcell.ColorGray)
|
SetSelectedBackgroundColor(tcell.ColorGray)
|
||||||
schemeListWidget.SetTitle("Select Colorscheme").SetBorder(true)
|
schemeListWidget.SetTitle("Select Colorscheme").SetBorder(true)
|
||||||
|
|
||||||
currentScheme := "default"
|
currentScheme := "default"
|
||||||
for name := range colorschemes {
|
for name := range colorschemes {
|
||||||
if tview.Styles == colorschemes[name] {
|
if tview.Styles == colorschemes[name] {
|
||||||
|
|||||||
@@ -131,7 +131,6 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
|||||||
}
|
}
|
||||||
embeddings[data.Index] = data.Embedding
|
embeddings[data.Index] = data.Embedding
|
||||||
}
|
}
|
||||||
|
|
||||||
return embeddings, nil
|
return embeddings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -95,9 +95,7 @@ func extractTextFromEpub(fpath string) (string, error) {
|
|||||||
return "", fmt.Errorf("failed to open epub: %w", err)
|
return "", fmt.Errorf("failed to open epub: %w", err)
|
||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
for _, f := range r.File {
|
for _, f := range r.File {
|
||||||
ext := strings.ToLower(path.Ext(f.Name))
|
ext := strings.ToLower(path.Ext(f.Name))
|
||||||
if ext != ".xhtml" && ext != ".html" && ext != ".htm" && ext != ".xml" {
|
if ext != ".xhtml" && ext != ".html" && ext != ".htm" && ext != ".xml" {
|
||||||
@@ -129,7 +127,6 @@ func extractTextFromEpub(fpath string) (string, error) {
|
|||||||
sb.WriteString(stripHTML(string(buf)))
|
sb.WriteString(stripHTML(string(buf)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sb.Len() == 0 {
|
if sb.Len() == 0 {
|
||||||
return "", errors.New("no content extracted from epub")
|
return "", errors.New("no content extracted from epub")
|
||||||
}
|
}
|
||||||
|
|||||||
55
rag/rag.go
55
rag/rag.go
@@ -36,7 +36,6 @@ type RAG struct {
|
|||||||
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
|
||||||
// Initialize with API embedder by default, could be configurable later
|
// Initialize with API embedder by default, could be configurable later
|
||||||
embedder := NewAPIEmbedder(l, cfg)
|
embedder := NewAPIEmbedder(l, cfg)
|
||||||
|
|
||||||
rag := &RAG{
|
rag := &RAG{
|
||||||
logger: l,
|
logger: l,
|
||||||
store: s,
|
store: s,
|
||||||
@@ -205,29 +204,22 @@ var (
|
|||||||
func (r *RAG) RefineQuery(query string) string {
|
func (r *RAG) RefineQuery(query string) string {
|
||||||
original := query
|
original := query
|
||||||
query = strings.TrimSpace(query)
|
query = strings.TrimSpace(query)
|
||||||
|
|
||||||
if len(query) == 0 {
|
if len(query) == 0 {
|
||||||
return original
|
return original
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(query) <= 3 {
|
if len(query) <= 3 {
|
||||||
return original
|
return original
|
||||||
}
|
}
|
||||||
|
|
||||||
query = strings.ToLower(query)
|
query = strings.ToLower(query)
|
||||||
|
|
||||||
for _, stopWord := range stopWords {
|
for _, stopWord := range stopWords {
|
||||||
wordPattern := `\b` + stopWord + `\b`
|
wordPattern := `\b` + stopWord + `\b`
|
||||||
re := regexp.MustCompile(wordPattern)
|
re := regexp.MustCompile(wordPattern)
|
||||||
query = re.ReplaceAllString(query, "")
|
query = re.ReplaceAllString(query, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
query = strings.TrimSpace(query)
|
query = strings.TrimSpace(query)
|
||||||
|
|
||||||
if len(query) < 5 {
|
if len(query) < 5 {
|
||||||
return original
|
return original
|
||||||
}
|
}
|
||||||
|
|
||||||
if queryRefinementPattern.MatchString(original) {
|
if queryRefinementPattern.MatchString(original) {
|
||||||
cleaned := queryRefinementPattern.ReplaceAllString(original, "")
|
cleaned := queryRefinementPattern.ReplaceAllString(original, "")
|
||||||
cleaned = strings.TrimSpace(cleaned)
|
cleaned = strings.TrimSpace(cleaned)
|
||||||
@@ -235,23 +227,18 @@ func (r *RAG) RefineQuery(query string) string {
|
|||||||
return cleaned
|
return cleaned
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
query = r.extractImportantPhrases(query)
|
query = r.extractImportantPhrases(query)
|
||||||
|
|
||||||
if len(query) < 5 {
|
if len(query) < 5 {
|
||||||
return original
|
return original
|
||||||
}
|
}
|
||||||
|
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) extractImportantPhrases(query string) string {
|
func (r *RAG) extractImportantPhrases(query string) string {
|
||||||
words := strings.Fields(query)
|
words := strings.Fields(query)
|
||||||
|
|
||||||
var important []string
|
var important []string
|
||||||
for _, word := range words {
|
for _, word := range words {
|
||||||
word = strings.Trim(word, ".,!?;:'\"()[]{}")
|
word = strings.Trim(word, ".,!?;:'\"()[]{}")
|
||||||
|
|
||||||
isImportant := false
|
isImportant := false
|
||||||
for _, kw := range importantKeywords {
|
for _, kw := range importantKeywords {
|
||||||
if strings.Contains(strings.ToLower(word), kw) {
|
if strings.Contains(strings.ToLower(word), kw) {
|
||||||
@@ -259,45 +246,37 @@ func (r *RAG) extractImportantPhrases(query string) string {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isImportant || len(word) > 3 {
|
if isImportant || len(word) > 3 {
|
||||||
important = append(important, word)
|
important = append(important, word)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(important) == 0 {
|
if len(important) == 0 {
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(important, " ")
|
return strings.Join(important, " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RAG) GenerateQueryVariations(query string) []string {
|
func (r *RAG) GenerateQueryVariations(query string) []string {
|
||||||
variations := []string{query}
|
variations := []string{query}
|
||||||
|
|
||||||
if len(query) < 5 {
|
if len(query) < 5 {
|
||||||
return variations
|
return variations
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Fields(query)
|
parts := strings.Fields(query)
|
||||||
if len(parts) == 0 {
|
if len(parts) == 0 {
|
||||||
return variations
|
return variations
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
trimmed := strings.Join(parts[:len(parts)-1], " ")
|
trimmed := strings.Join(parts[:len(parts)-1], " ")
|
||||||
if len(trimmed) >= 5 {
|
if len(trimmed) >= 5 {
|
||||||
variations = append(variations, trimmed)
|
variations = append(variations, trimmed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
trimmed := strings.Join(parts[1:], " ")
|
trimmed := strings.Join(parts[1:], " ")
|
||||||
if len(trimmed) >= 5 {
|
if len(trimmed) >= 5 {
|
||||||
variations = append(variations, trimmed)
|
variations = append(variations, trimmed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasSuffix(query, " explanation") {
|
if !strings.HasSuffix(query, " explanation") {
|
||||||
variations = append(variations, query+" explanation")
|
variations = append(variations, query+" explanation")
|
||||||
}
|
}
|
||||||
@@ -310,7 +289,6 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
|
|||||||
if !strings.HasSuffix(query, " summary") {
|
if !strings.HasSuffix(query, " summary") {
|
||||||
variations = append(variations, query+" summary")
|
variations = append(variations, query+" summary")
|
||||||
}
|
}
|
||||||
|
|
||||||
return variations
|
return variations
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,21 +297,16 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
|||||||
row models.VectorRow
|
row models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
|
|
||||||
scored := make([]scoredResult, 0, len(results))
|
scored := make([]scoredResult, 0, len(results))
|
||||||
|
|
||||||
for i := range results {
|
for i := range results {
|
||||||
row := results[i]
|
row := results[i]
|
||||||
|
|
||||||
score := float32(0)
|
score := float32(0)
|
||||||
|
|
||||||
rawTextLower := strings.ToLower(row.RawText)
|
rawTextLower := strings.ToLower(row.RawText)
|
||||||
queryLower := strings.ToLower(query)
|
queryLower := strings.ToLower(query)
|
||||||
|
|
||||||
if strings.Contains(rawTextLower, queryLower) {
|
if strings.Contains(rawTextLower, queryLower) {
|
||||||
score += 10
|
score += 10
|
||||||
}
|
}
|
||||||
|
|
||||||
queryWords := strings.Fields(queryLower)
|
queryWords := strings.Fields(queryLower)
|
||||||
matchCount := 0
|
matchCount := 0
|
||||||
for _, word := range queryWords {
|
for _, word := range queryWords {
|
||||||
@@ -344,34 +317,26 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
|
|||||||
if len(queryWords) > 0 {
|
if len(queryWords) > 0 {
|
||||||
score += float32(matchCount) / float32(len(queryWords)) * 5
|
score += float32(matchCount) / float32(len(queryWords)) * 5
|
||||||
}
|
}
|
||||||
|
|
||||||
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
|
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
|
||||||
score += 3
|
score += 3
|
||||||
}
|
}
|
||||||
|
|
||||||
distance := row.Distance - score/100
|
distance := row.Distance - score/100
|
||||||
|
|
||||||
scored = append(scored, scoredResult{row: row, distance: distance})
|
scored = append(scored, scoredResult{row: row, distance: distance})
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(scored, func(i, j int) bool {
|
sort.Slice(scored, func(i, j int) bool {
|
||||||
return scored[i].distance < scored[j].distance
|
return scored[i].distance < scored[j].distance
|
||||||
})
|
})
|
||||||
|
|
||||||
unique := make([]models.VectorRow, 0)
|
unique := make([]models.VectorRow, 0)
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
for i := range scored {
|
for i := range scored {
|
||||||
if !seen[scored[i].row.Slug] {
|
if !seen[scored[i].row.Slug] {
|
||||||
seen[scored[i].row.Slug] = true
|
seen[scored[i].row.Slug] = true
|
||||||
unique = append(unique, scored[i].row)
|
unique = append(unique, scored[i].row)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(unique) > 10 {
|
if len(unique) > 10 {
|
||||||
unique = unique[:10]
|
unique = unique[:10]
|
||||||
}
|
}
|
||||||
|
|
||||||
return unique
|
return unique
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,58 +344,47 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
|
|||||||
if len(results) == 0 {
|
if len(results) == 0 {
|
||||||
return "No relevant information found in the vector database.", nil
|
return "No relevant information found in the vector database.", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var contextBuilder strings.Builder
|
var contextBuilder strings.Builder
|
||||||
contextBuilder.WriteString("User Query: ")
|
contextBuilder.WriteString("User Query: ")
|
||||||
contextBuilder.WriteString(query)
|
contextBuilder.WriteString(query)
|
||||||
contextBuilder.WriteString("\n\nRetrieved Context:\n")
|
contextBuilder.WriteString("\n\nRetrieved Context:\n")
|
||||||
|
|
||||||
for i, row := range results {
|
for i, row := range results {
|
||||||
contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName))
|
fmt.Fprintf(&contextBuilder, "[Source %d: %s]\n", i+1, row.FileName)
|
||||||
contextBuilder.WriteString(row.RawText)
|
contextBuilder.WriteString(row.RawText)
|
||||||
contextBuilder.WriteString("\n\n")
|
contextBuilder.WriteString("\n\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
contextBuilder.WriteString("Instructions: ")
|
contextBuilder.WriteString("Instructions: ")
|
||||||
contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
|
contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
|
||||||
contextBuilder.WriteString("Extract only the most relevant information. ")
|
contextBuilder.WriteString("Extract only the most relevant information. ")
|
||||||
contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
|
contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
|
||||||
contextBuilder.WriteString("Cite sources by filename when relevant. ")
|
contextBuilder.WriteString("Cite sources by filename when relevant. ")
|
||||||
contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
|
contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
|
||||||
|
|
||||||
synthesisPrompt := contextBuilder.String()
|
synthesisPrompt := contextBuilder.String()
|
||||||
|
|
||||||
emb, err := r.LineToVector(synthesisPrompt)
|
emb, err := r.LineToVector(synthesisPrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to embed synthesis prompt", "error", err)
|
r.logger.Error("failed to embed synthesis prompt", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
embResp := &models.EmbeddingResp{
|
embResp := &models.EmbeddingResp{
|
||||||
Embedding: emb,
|
Embedding: emb,
|
||||||
Index: 0,
|
Index: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
topResults, err := r.SearchEmb(embResp)
|
topResults, err := r.SearchEmb(embResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.Error("failed to search for synthesis context", "error", err)
|
r.logger.Error("failed to search for synthesis context", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
|
if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
|
||||||
return topResults[0].RawText, nil
|
return topResults[0].RawText, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var finalAnswer strings.Builder
|
var finalAnswer strings.Builder
|
||||||
finalAnswer.WriteString("Based on the retrieved context:\n\n")
|
finalAnswer.WriteString("Based on the retrieved context:\n\n")
|
||||||
|
|
||||||
for i, row := range results {
|
for i, row := range results {
|
||||||
if i >= 5 {
|
if i >= 5 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200)))
|
fmt.Fprintf(&finalAnswer, "- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))
|
||||||
}
|
}
|
||||||
|
|
||||||
return finalAnswer.String(), nil
|
return finalAnswer.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -444,10 +398,8 @@ func truncateString(s string, maxLen int) string {
|
|||||||
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
||||||
refined := r.RefineQuery(query)
|
refined := r.RefineQuery(query)
|
||||||
variations := r.GenerateQueryVariations(refined)
|
variations := r.GenerateQueryVariations(refined)
|
||||||
|
|
||||||
allResults := make([]models.VectorRow, 0)
|
allResults := make([]models.VectorRow, 0)
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
for _, q := range variations {
|
for _, q := range variations {
|
||||||
emb, err := r.LineToVector(q)
|
emb, err := r.LineToVector(q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -473,13 +425,10 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reranked := r.RerankResults(allResults, query)
|
reranked := r.RerankResults(allResults, query)
|
||||||
|
|
||||||
if len(reranked) > limit {
|
if len(reranked) > limit {
|
||||||
reranked = reranked[:limit]
|
reranked = reranked[:limit]
|
||||||
}
|
}
|
||||||
|
|
||||||
return reranked, nil
|
return reranked, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ func NewVectorStorage(logger *slog.Logger, store storage.FullRepo) *VectorStorag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// SerializeVector converts []float32 to binary blob
|
// SerializeVector converts []float32 to binary blob
|
||||||
func SerializeVector(vec []float32) []byte {
|
func SerializeVector(vec []float32) []byte {
|
||||||
buf := make([]byte, len(vec)*4) // 4 bytes per float32
|
buf := make([]byte, len(vec)*4) // 4 bytes per float32
|
||||||
@@ -66,17 +65,14 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
|
|||||||
|
|
||||||
// Serialize the embeddings to binary
|
// Serialize the embeddings to binary
|
||||||
serializedEmbeddings := SerializeVector(row.Embeddings)
|
serializedEmbeddings := SerializeVector(row.Embeddings)
|
||||||
|
|
||||||
query := fmt.Sprintf(
|
query := fmt.Sprintf(
|
||||||
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
|
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
|
||||||
tableName,
|
tableName,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
|
if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
|
||||||
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
|
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,20 +82,18 @@ func (vs *VectorStorage) getTableName(emb []float32) (string, error) {
|
|||||||
|
|
||||||
// Check if we support this embedding size
|
// Check if we support this embedding size
|
||||||
supportedSizes := map[int]bool{
|
supportedSizes := map[int]bool{
|
||||||
384: true,
|
384: true,
|
||||||
768: true,
|
768: true,
|
||||||
1024: true,
|
1024: true,
|
||||||
1536: true,
|
1536: true,
|
||||||
2048: true,
|
2048: true,
|
||||||
3072: true,
|
3072: true,
|
||||||
4096: true,
|
4096: true,
|
||||||
5120: true,
|
5120: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if supportedSizes[size] {
|
if supportedSizes[size] {
|
||||||
return fmt.Sprintf("embeddings_%d", size), nil
|
return fmt.Sprintf("embeddings_%d", size), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("no table for embedding size of %d", size)
|
return "", fmt.Errorf("no table for embedding size of %d", size)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,9 +120,7 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
vector models.VectorRow
|
vector models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
|
|
||||||
var topResults []SearchResult
|
var topResults []SearchResult
|
||||||
|
|
||||||
// Process vectors one by one to avoid loading everything into memory
|
// Process vectors one by one to avoid loading everything into memory
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
@@ -176,14 +168,12 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
|
|||||||
result.vector.Distance = result.distance
|
result.vector.Distance = result.distance
|
||||||
results = append(results, result.vector)
|
results = append(results, result.vector)
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListFiles returns a list of all loaded files
|
// ListFiles returns a list of all loaded files
|
||||||
func (vs *VectorStorage) ListFiles() ([]string, error) {
|
func (vs *VectorStorage) ListFiles() ([]string, error) {
|
||||||
fileLists := make([][]string, 0)
|
fileLists := make([][]string, 0)
|
||||||
|
|
||||||
// Query all supported tables and combine results
|
// Query all supported tables and combine results
|
||||||
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
||||||
for _, size := range embeddingSizes {
|
for _, size := range embeddingSizes {
|
||||||
@@ -219,14 +209,12 @@ func (vs *VectorStorage) ListFiles() ([]string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return allFiles, nil
|
return allFiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveEmbByFileName removes all embeddings associated with a specific filename
|
// RemoveEmbByFileName removes all embeddings associated with a specific filename
|
||||||
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
||||||
var errors []string
|
var errors []string
|
||||||
|
|
||||||
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
|
||||||
for _, size := range embeddingSizes {
|
for _, size := range embeddingSizes {
|
||||||
table := fmt.Sprintf("embeddings_%d", size)
|
table := fmt.Sprintf("embeddings_%d", size)
|
||||||
@@ -235,11 +223,9 @@ func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
|
|||||||
errors = append(errors, err.Error())
|
errors = append(errors, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(errors) > 0 {
|
if len(errors) > 0 {
|
||||||
return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; "))
|
return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; "))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,18 +234,15 @@ func cosineSimilarity(a, b []float32) float32 {
|
|||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
var dotProduct, normA, normB float32
|
var dotProduct, normA, normB float32
|
||||||
for i := 0; i < len(a); i++ {
|
for i := 0; i < len(a); i++ {
|
||||||
dotProduct += a[i] * b[i]
|
dotProduct += a[i] * b[i]
|
||||||
normA += a[i] * a[i]
|
normA += a[i] * a[i]
|
||||||
normB += b[i] * b[i]
|
normB += b[i] * b[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
if normA == 0 || normB == 0 {
|
if normA == 0 || normB == 0 {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,4 +258,3 @@ func sqrt(f float32) float32 {
|
|||||||
}
|
}
|
||||||
return guess
|
return guess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -103,7 +103,6 @@ func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
p := ProviderSQL{db: db, logger: logger}
|
p := ProviderSQL{db: db, logger: logger}
|
||||||
|
|
||||||
p.Migrate()
|
p.Migrate()
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,12 +73,9 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
serializedEmbeddings := SerializeVector(row.Embeddings)
|
serializedEmbeddings := SerializeVector(row.Embeddings)
|
||||||
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
|
query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
|
||||||
_, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName)
|
_, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,27 +84,22 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
|
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
|
||||||
rows, err := p.db.Query(querySQL)
|
rows, err := p.db.Query(querySQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
type SearchResult struct {
|
type SearchResult struct {
|
||||||
vector models.VectorRow
|
vector models.VectorRow
|
||||||
distance float32
|
distance float32
|
||||||
}
|
}
|
||||||
|
|
||||||
var topResults []SearchResult
|
var topResults []SearchResult
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
embeddingsBlob []byte
|
embeddingsBlob []byte
|
||||||
slug, rawText, fileName string
|
slug, rawText, fileName string
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil {
|
if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -152,7 +144,6 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
|
|||||||
result.vector.Distance = result.distance
|
result.vector.Distance = result.distance
|
||||||
results[i] = result.vector
|
results[i] = result.vector
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,18 +152,15 @@ func cosineSimilarity(a, b []float32) float32 {
|
|||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
var dotProduct, normA, normB float32
|
var dotProduct, normA, normB float32
|
||||||
for i := 0; i < len(a); i++ {
|
for i := 0; i < len(a); i++ {
|
||||||
dotProduct += a[i] * b[i]
|
dotProduct += a[i] * b[i]
|
||||||
normA += a[i] * a[i]
|
normA += a[i] * a[i]
|
||||||
normB += b[i] * b[i]
|
normB += b[i] * b[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
if normA == 0 || normB == 0 {
|
if normA == 0 || normB == 0 {
|
||||||
return 0.0
|
return 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,13 +217,11 @@ func (p ProviderSQL) ListFiles() ([]string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return allFiles, nil
|
return allFiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
||||||
var errors []string
|
var errors []string
|
||||||
|
|
||||||
tableNames := []string{
|
tableNames := []string{
|
||||||
"embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536",
|
"embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536",
|
||||||
"embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120",
|
"embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120",
|
||||||
@@ -246,10 +232,8 @@ func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
|
|||||||
errors = append(errors, err.Error())
|
errors = append(errors, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(errors) > 0 {
|
if len(errors) > 0 {
|
||||||
return fmt.Errorf("errors occurred: %v", errors)
|
return fmt.Errorf("errors occurred: %v", errors)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -287,7 +287,6 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rows := len(ragFiles)
|
rows := len(ragFiles)
|
||||||
cols := 4 // File Name | Preview | Action | Delete
|
cols := 4 // File Name | Preview | Action | Delete
|
||||||
fileTable := tview.NewTable().
|
fileTable := tview.NewTable().
|
||||||
|
|||||||
Reference in New Issue
Block a user