Chore: linter complaints

This commit is contained in:
Grail Finder
2026-02-25 20:06:56 +03:00
parent 4f07994bdc
commit 888c9fec65
12 changed files with 11 additions and 123 deletions

View File

@@ -140,7 +140,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
ag.log.Error("failed to read request body", "error", err) ag.log.Error("failed to read request body", "error", err)
return nil, err return nil, err
} }
req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes)) req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
if err != nil { if err != nil {
ag.log.Error("failed to create request", "error", err) ag.log.Error("failed to create request", "error", err)
@@ -150,22 +149,18 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer "+ag.getToken()) req.Header.Add("Authorization", "Bearer "+ag.getToken())
req.Header.Set("Accept-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")
ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)])) ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
resp, err := httpClient.Do(req) resp, err := httpClient.Do(req)
if err != nil { if err != nil {
ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI) ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
responseBytes, err := io.ReadAll(resp.Body) responseBytes, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
ag.log.Error("failed to read response", "error", err) ag.log.Error("failed to read response", "error", err)
return nil, err return nil, err
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)])) ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)])) return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
@@ -178,7 +173,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
// Return raw response as fallback // Return raw response as fallback
return responseBytes, nil return responseBytes, nil
} }
return []byte(text), nil return []byte(text), nil
} }

View File

@@ -32,7 +32,6 @@ func startModelColorUpdater() {
// Initial check // Initial check
updateCachedModelColor() updateCachedModelColor()
for range ticker.C { for range ticker.C {
updateCachedModelColor() updateCachedModelColor()
} }

3
llm.go
View File

@@ -216,13 +216,11 @@ func (op LCPChat) ParseChunk(data []byte) (*models.TextChunk, error) {
logger.Warn("LCPChat ParseChunk: no choices in response", "data", string(data)) logger.Warn("LCPChat ParseChunk: no choices in response", "data", string(data))
return &models.TextChunk{Finished: true}, nil return &models.TextChunk{Finished: true}, nil
} }
lastChoice := llmchunk.Choices[len(llmchunk.Choices)-1] lastChoice := llmchunk.Choices[len(llmchunk.Choices)-1]
resp := &models.TextChunk{ resp := &models.TextChunk{
Chunk: lastChoice.Delta.Content, Chunk: lastChoice.Delta.Content,
Reasoning: lastChoice.Delta.ReasoningContent, Reasoning: lastChoice.Delta.ReasoningContent,
} }
// Check for tool calls in all choices, not just the last one // Check for tool calls in all choices, not just the last one
for _, choice := range llmchunk.Choices { for _, choice := range llmchunk.Choices {
if len(choice.Delta.ToolCalls) > 0 { if len(choice.Delta.ToolCalls) > 0 {
@@ -237,7 +235,6 @@ func (op LCPChat) ParseChunk(data []byte) (*models.TextChunk, error) {
break // Process only the first tool call break // Process only the first tool call
} }
} }
if lastChoice.FinishReason == "stop" { if lastChoice.FinishReason == "stop" {
if resp.Chunk != "" { if resp.Chunk != "" {
logger.Error("text inside of finish llmchunk", "chunk", llmchunk) logger.Error("text inside of finish llmchunk", "chunk", llmchunk)

View File

@@ -400,7 +400,6 @@ func (m *RoleMsg) AddTextPart(text string) {
} }
m.hasContentParts = true m.hasContentParts = true
} }
textPart := TextContentPart{Type: "text", Text: text} textPart := TextContentPart{Type: "text", Text: text}
m.ContentParts = append(m.ContentParts, textPart) m.ContentParts = append(m.ContentParts, textPart)
} }
@@ -416,7 +415,6 @@ func (m *RoleMsg) AddImagePart(imageURL, imagePath string) {
} }
m.hasContentParts = true m.hasContentParts = true
} }
imagePart := ImageContentPart{ imagePart := ImageContentPart{
Type: "image_url", Type: "image_url",
Path: imagePath, // Store the original file path Path: imagePath, // Store the original file path

View File

@@ -410,38 +410,30 @@ func updateWidgetColors(theme *tview.Theme) {
fgColor := theme.PrimaryTextColor fgColor := theme.PrimaryTextColor
borderColor := theme.BorderColor borderColor := theme.BorderColor
titleColor := theme.TitleColor titleColor := theme.TitleColor
textView.SetBackgroundColor(bgColor) textView.SetBackgroundColor(bgColor)
textView.SetTextColor(fgColor) textView.SetTextColor(fgColor)
textView.SetBorderColor(borderColor) textView.SetBorderColor(borderColor)
textView.SetTitleColor(titleColor) textView.SetTitleColor(titleColor)
textArea.SetBackgroundColor(bgColor) textArea.SetBackgroundColor(bgColor)
textArea.SetBorderColor(borderColor) textArea.SetBorderColor(borderColor)
textArea.SetTitleColor(titleColor) textArea.SetTitleColor(titleColor)
textArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor)) textArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
textArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor)) textArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
// Force textarea refresh by restoring text (SetTextStyle doesn't trigger redraw)
textArea.SetText(textArea.GetText(), true) textArea.SetText(textArea.GetText(), true)
editArea.SetBackgroundColor(bgColor) editArea.SetBackgroundColor(bgColor)
editArea.SetBorderColor(borderColor) editArea.SetBorderColor(borderColor)
editArea.SetTitleColor(titleColor) editArea.SetTitleColor(titleColor)
editArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor)) editArea.SetTextStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
editArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor)) editArea.SetPlaceholderStyle(tcell.StyleDefault.Background(bgColor).Foreground(fgColor))
// Force textarea refresh by restoring text (SetTextStyle doesn't trigger redraw)
editArea.SetText(editArea.GetText(), true) editArea.SetText(editArea.GetText(), true)
statusLineWidget.SetBackgroundColor(bgColor) statusLineWidget.SetBackgroundColor(bgColor)
statusLineWidget.SetTextColor(fgColor) statusLineWidget.SetTextColor(fgColor)
statusLineWidget.SetBorderColor(borderColor) statusLineWidget.SetBorderColor(borderColor)
statusLineWidget.SetTitleColor(titleColor) statusLineWidget.SetTitleColor(titleColor)
helpView.SetBackgroundColor(bgColor) helpView.SetBackgroundColor(bgColor)
helpView.SetTextColor(fgColor) helpView.SetTextColor(fgColor)
helpView.SetBorderColor(borderColor) helpView.SetBorderColor(borderColor)
helpView.SetTitleColor(titleColor) helpView.SetTitleColor(titleColor)
searchField.SetBackgroundColor(bgColor) searchField.SetBackgroundColor(bgColor)
searchField.SetBorderColor(borderColor) searchField.SetBorderColor(borderColor)
searchField.SetTitleColor(titleColor) searchField.SetTitleColor(titleColor)
@@ -468,7 +460,6 @@ func showColorschemeSelectionPopup() {
schemeListWidget := tview.NewList().ShowSecondaryText(false). schemeListWidget := tview.NewList().ShowSecondaryText(false).
SetSelectedBackgroundColor(tcell.ColorGray) SetSelectedBackgroundColor(tcell.ColorGray)
schemeListWidget.SetTitle("Select Colorscheme").SetBorder(true) schemeListWidget.SetTitle("Select Colorscheme").SetBorder(true)
currentScheme := "default" currentScheme := "default"
for name := range colorschemes { for name := range colorschemes {
if tview.Styles == colorschemes[name] { if tview.Styles == colorschemes[name] {

View File

@@ -131,7 +131,6 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
} }
embeddings[data.Index] = data.Embedding embeddings[data.Index] = data.Embedding
} }
return embeddings, nil return embeddings, nil
} }

View File

@@ -95,9 +95,7 @@ func extractTextFromEpub(fpath string) (string, error) {
return "", fmt.Errorf("failed to open epub: %w", err) return "", fmt.Errorf("failed to open epub: %w", err)
} }
defer r.Close() defer r.Close()
var sb strings.Builder var sb strings.Builder
for _, f := range r.File { for _, f := range r.File {
ext := strings.ToLower(path.Ext(f.Name)) ext := strings.ToLower(path.Ext(f.Name))
if ext != ".xhtml" && ext != ".html" && ext != ".htm" && ext != ".xml" { if ext != ".xhtml" && ext != ".html" && ext != ".htm" && ext != ".xml" {
@@ -129,7 +127,6 @@ func extractTextFromEpub(fpath string) (string, error) {
sb.WriteString(stripHTML(string(buf))) sb.WriteString(stripHTML(string(buf)))
} }
} }
if sb.Len() == 0 { if sb.Len() == 0 {
return "", errors.New("no content extracted from epub") return "", errors.New("no content extracted from epub")
} }

View File

@@ -36,7 +36,6 @@ type RAG struct {
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
// Initialize with API embedder by default, could be configurable later // Initialize with API embedder by default, could be configurable later
embedder := NewAPIEmbedder(l, cfg) embedder := NewAPIEmbedder(l, cfg)
rag := &RAG{ rag := &RAG{
logger: l, logger: l,
store: s, store: s,
@@ -205,29 +204,22 @@ var (
func (r *RAG) RefineQuery(query string) string { func (r *RAG) RefineQuery(query string) string {
original := query original := query
query = strings.TrimSpace(query) query = strings.TrimSpace(query)
if len(query) == 0 { if len(query) == 0 {
return original return original
} }
if len(query) <= 3 { if len(query) <= 3 {
return original return original
} }
query = strings.ToLower(query) query = strings.ToLower(query)
for _, stopWord := range stopWords { for _, stopWord := range stopWords {
wordPattern := `\b` + stopWord + `\b` wordPattern := `\b` + stopWord + `\b`
re := regexp.MustCompile(wordPattern) re := regexp.MustCompile(wordPattern)
query = re.ReplaceAllString(query, "") query = re.ReplaceAllString(query, "")
} }
query = strings.TrimSpace(query) query = strings.TrimSpace(query)
if len(query) < 5 { if len(query) < 5 {
return original return original
} }
if queryRefinementPattern.MatchString(original) { if queryRefinementPattern.MatchString(original) {
cleaned := queryRefinementPattern.ReplaceAllString(original, "") cleaned := queryRefinementPattern.ReplaceAllString(original, "")
cleaned = strings.TrimSpace(cleaned) cleaned = strings.TrimSpace(cleaned)
@@ -235,23 +227,18 @@ func (r *RAG) RefineQuery(query string) string {
return cleaned return cleaned
} }
} }
query = r.extractImportantPhrases(query) query = r.extractImportantPhrases(query)
if len(query) < 5 { if len(query) < 5 {
return original return original
} }
return query return query
} }
func (r *RAG) extractImportantPhrases(query string) string { func (r *RAG) extractImportantPhrases(query string) string {
words := strings.Fields(query) words := strings.Fields(query)
var important []string var important []string
for _, word := range words { for _, word := range words {
word = strings.Trim(word, ".,!?;:'\"()[]{}") word = strings.Trim(word, ".,!?;:'\"()[]{}")
isImportant := false isImportant := false
for _, kw := range importantKeywords { for _, kw := range importantKeywords {
if strings.Contains(strings.ToLower(word), kw) { if strings.Contains(strings.ToLower(word), kw) {
@@ -259,45 +246,37 @@ func (r *RAG) extractImportantPhrases(query string) string {
break break
} }
} }
if isImportant || len(word) > 3 { if isImportant || len(word) > 3 {
important = append(important, word) important = append(important, word)
} }
} }
if len(important) == 0 { if len(important) == 0 {
return query return query
} }
return strings.Join(important, " ") return strings.Join(important, " ")
} }
func (r *RAG) GenerateQueryVariations(query string) []string { func (r *RAG) GenerateQueryVariations(query string) []string {
variations := []string{query} variations := []string{query}
if len(query) < 5 { if len(query) < 5 {
return variations return variations
} }
parts := strings.Fields(query) parts := strings.Fields(query)
if len(parts) == 0 { if len(parts) == 0 {
return variations return variations
} }
if len(parts) >= 2 { if len(parts) >= 2 {
trimmed := strings.Join(parts[:len(parts)-1], " ") trimmed := strings.Join(parts[:len(parts)-1], " ")
if len(trimmed) >= 5 { if len(trimmed) >= 5 {
variations = append(variations, trimmed) variations = append(variations, trimmed)
} }
} }
if len(parts) >= 2 { if len(parts) >= 2 {
trimmed := strings.Join(parts[1:], " ") trimmed := strings.Join(parts[1:], " ")
if len(trimmed) >= 5 { if len(trimmed) >= 5 {
variations = append(variations, trimmed) variations = append(variations, trimmed)
} }
} }
if !strings.HasSuffix(query, " explanation") { if !strings.HasSuffix(query, " explanation") {
variations = append(variations, query+" explanation") variations = append(variations, query+" explanation")
} }
@@ -310,7 +289,6 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
if !strings.HasSuffix(query, " summary") { if !strings.HasSuffix(query, " summary") {
variations = append(variations, query+" summary") variations = append(variations, query+" summary")
} }
return variations return variations
} }
@@ -319,21 +297,16 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
row models.VectorRow row models.VectorRow
distance float32 distance float32
} }
scored := make([]scoredResult, 0, len(results)) scored := make([]scoredResult, 0, len(results))
for i := range results { for i := range results {
row := results[i] row := results[i]
score := float32(0) score := float32(0)
rawTextLower := strings.ToLower(row.RawText) rawTextLower := strings.ToLower(row.RawText)
queryLower := strings.ToLower(query) queryLower := strings.ToLower(query)
if strings.Contains(rawTextLower, queryLower) { if strings.Contains(rawTextLower, queryLower) {
score += 10 score += 10
} }
queryWords := strings.Fields(queryLower) queryWords := strings.Fields(queryLower)
matchCount := 0 matchCount := 0
for _, word := range queryWords { for _, word := range queryWords {
@@ -344,34 +317,26 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
if len(queryWords) > 0 { if len(queryWords) > 0 {
score += float32(matchCount) / float32(len(queryWords)) * 5 score += float32(matchCount) / float32(len(queryWords)) * 5
} }
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") { if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
score += 3 score += 3
} }
distance := row.Distance - score/100 distance := row.Distance - score/100
scored = append(scored, scoredResult{row: row, distance: distance}) scored = append(scored, scoredResult{row: row, distance: distance})
} }
sort.Slice(scored, func(i, j int) bool { sort.Slice(scored, func(i, j int) bool {
return scored[i].distance < scored[j].distance return scored[i].distance < scored[j].distance
}) })
unique := make([]models.VectorRow, 0) unique := make([]models.VectorRow, 0)
seen := make(map[string]bool) seen := make(map[string]bool)
for i := range scored { for i := range scored {
if !seen[scored[i].row.Slug] { if !seen[scored[i].row.Slug] {
seen[scored[i].row.Slug] = true seen[scored[i].row.Slug] = true
unique = append(unique, scored[i].row) unique = append(unique, scored[i].row)
} }
} }
if len(unique) > 10 { if len(unique) > 10 {
unique = unique[:10] unique = unique[:10]
} }
return unique return unique
} }
@@ -379,58 +344,47 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
if len(results) == 0 { if len(results) == 0 {
return "No relevant information found in the vector database.", nil return "No relevant information found in the vector database.", nil
} }
var contextBuilder strings.Builder var contextBuilder strings.Builder
contextBuilder.WriteString("User Query: ") contextBuilder.WriteString("User Query: ")
contextBuilder.WriteString(query) contextBuilder.WriteString(query)
contextBuilder.WriteString("\n\nRetrieved Context:\n") contextBuilder.WriteString("\n\nRetrieved Context:\n")
for i, row := range results { for i, row := range results {
contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName)) fmt.Fprintf(&contextBuilder, "[Source %d: %s]\n", i+1, row.FileName)
contextBuilder.WriteString(row.RawText) contextBuilder.WriteString(row.RawText)
contextBuilder.WriteString("\n\n") contextBuilder.WriteString("\n\n")
} }
contextBuilder.WriteString("Instructions: ") contextBuilder.WriteString("Instructions: ")
contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ") contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
contextBuilder.WriteString("Extract only the most relevant information. ") contextBuilder.WriteString("Extract only the most relevant information. ")
contextBuilder.WriteString("If no relevant information is found, state that clearly. ") contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
contextBuilder.WriteString("Cite sources by filename when relevant. ") contextBuilder.WriteString("Cite sources by filename when relevant. ")
contextBuilder.WriteString("Do not include unnecessary preamble or explanations.") contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
synthesisPrompt := contextBuilder.String() synthesisPrompt := contextBuilder.String()
emb, err := r.LineToVector(synthesisPrompt) emb, err := r.LineToVector(synthesisPrompt)
if err != nil { if err != nil {
r.logger.Error("failed to embed synthesis prompt", "error", err) r.logger.Error("failed to embed synthesis prompt", "error", err)
return "", err return "", err
} }
embResp := &models.EmbeddingResp{ embResp := &models.EmbeddingResp{
Embedding: emb, Embedding: emb,
Index: 0, Index: 0,
} }
topResults, err := r.SearchEmb(embResp) topResults, err := r.SearchEmb(embResp)
if err != nil { if err != nil {
r.logger.Error("failed to search for synthesis context", "error", err) r.logger.Error("failed to search for synthesis context", "error", err)
return "", err return "", err
} }
if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt { if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
return topResults[0].RawText, nil return topResults[0].RawText, nil
} }
var finalAnswer strings.Builder var finalAnswer strings.Builder
finalAnswer.WriteString("Based on the retrieved context:\n\n") finalAnswer.WriteString("Based on the retrieved context:\n\n")
for i, row := range results { for i, row := range results {
if i >= 5 { if i >= 5 {
break break
} }
finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))) fmt.Fprintf(&finalAnswer, "- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))
} }
return finalAnswer.String(), nil return finalAnswer.String(), nil
} }
@@ -444,10 +398,8 @@ func truncateString(s string, maxLen int) string {
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
refined := r.RefineQuery(query) refined := r.RefineQuery(query)
variations := r.GenerateQueryVariations(refined) variations := r.GenerateQueryVariations(refined)
allResults := make([]models.VectorRow, 0) allResults := make([]models.VectorRow, 0)
seen := make(map[string]bool) seen := make(map[string]bool)
for _, q := range variations { for _, q := range variations {
emb, err := r.LineToVector(q) emb, err := r.LineToVector(q)
if err != nil { if err != nil {
@@ -473,13 +425,10 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
} }
} }
} }
reranked := r.RerankResults(allResults, query) reranked := r.RerankResults(allResults, query)
if len(reranked) > limit { if len(reranked) > limit {
reranked = reranked[:limit] reranked = reranked[:limit]
} }
return reranked, nil return reranked, nil
} }

View File

@@ -28,7 +28,6 @@ func NewVectorStorage(logger *slog.Logger, store storage.FullRepo) *VectorStorag
} }
} }
// SerializeVector converts []float32 to binary blob // SerializeVector converts []float32 to binary blob
func SerializeVector(vec []float32) []byte { func SerializeVector(vec []float32) []byte {
buf := make([]byte, len(vec)*4) // 4 bytes per float32 buf := make([]byte, len(vec)*4) // 4 bytes per float32
@@ -66,17 +65,14 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
// Serialize the embeddings to binary // Serialize the embeddings to binary
serializedEmbeddings := SerializeVector(row.Embeddings) serializedEmbeddings := SerializeVector(row.Embeddings)
query := fmt.Sprintf( query := fmt.Sprintf(
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", "INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
tableName, tableName,
) )
if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil { if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug) vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
return err return err
} }
return nil return nil
} }
@@ -95,11 +91,9 @@ func (vs *VectorStorage) getTableName(emb []float32) (string, error) {
4096: true, 4096: true,
5120: true, 5120: true,
} }
if supportedSizes[size] { if supportedSizes[size] {
return fmt.Sprintf("embeddings_%d", size), nil return fmt.Sprintf("embeddings_%d", size), nil
} }
return "", fmt.Errorf("no table for embedding size of %d", size) return "", fmt.Errorf("no table for embedding size of %d", size)
} }
@@ -126,9 +120,7 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
vector models.VectorRow vector models.VectorRow
distance float32 distance float32
} }
var topResults []SearchResult var topResults []SearchResult
// Process vectors one by one to avoid loading everything into memory // Process vectors one by one to avoid loading everything into memory
for rows.Next() { for rows.Next() {
var ( var (
@@ -176,14 +168,12 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
result.vector.Distance = result.distance result.vector.Distance = result.distance
results = append(results, result.vector) results = append(results, result.vector)
} }
return results, nil return results, nil
} }
// ListFiles returns a list of all loaded files // ListFiles returns a list of all loaded files
func (vs *VectorStorage) ListFiles() ([]string, error) { func (vs *VectorStorage) ListFiles() ([]string, error) {
fileLists := make([][]string, 0) fileLists := make([][]string, 0)
// Query all supported tables and combine results // Query all supported tables and combine results
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
for _, size := range embeddingSizes { for _, size := range embeddingSizes {
@@ -219,14 +209,12 @@ func (vs *VectorStorage) ListFiles() ([]string, error) {
} }
} }
} }
return allFiles, nil return allFiles, nil
} }
// RemoveEmbByFileName removes all embeddings associated with a specific filename // RemoveEmbByFileName removes all embeddings associated with a specific filename
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error { func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
var errors []string var errors []string
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
for _, size := range embeddingSizes { for _, size := range embeddingSizes {
table := fmt.Sprintf("embeddings_%d", size) table := fmt.Sprintf("embeddings_%d", size)
@@ -235,11 +223,9 @@ func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
errors = append(errors, err.Error()) errors = append(errors, err.Error())
} }
} }
if len(errors) > 0 { if len(errors) > 0 {
return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; ")) return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; "))
} }
return nil return nil
} }
@@ -248,18 +234,15 @@ func cosineSimilarity(a, b []float32) float32 {
if len(a) != len(b) { if len(a) != len(b) {
return 0.0 return 0.0
} }
var dotProduct, normA, normB float32 var dotProduct, normA, normB float32
for i := 0; i < len(a); i++ { for i := 0; i < len(a); i++ {
dotProduct += a[i] * b[i] dotProduct += a[i] * b[i]
normA += a[i] * a[i] normA += a[i] * a[i]
normB += b[i] * b[i] normB += b[i] * b[i]
} }
if normA == 0 || normB == 0 { if normA == 0 || normB == 0 {
return 0.0 return 0.0
} }
return dotProduct / (sqrt(normA) * sqrt(normB)) return dotProduct / (sqrt(normA) * sqrt(normB))
} }
@@ -275,4 +258,3 @@ func sqrt(f float32) float32 {
} }
return guess return guess
} }

View File

@@ -103,7 +103,6 @@ func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo {
return nil return nil
} }
p := ProviderSQL{db: db, logger: logger} p := ProviderSQL{db: db, logger: logger}
p.Migrate() p.Migrate()
return p return p
} }

View File

@@ -73,12 +73,9 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
if err != nil { if err != nil {
return err return err
} }
serializedEmbeddings := SerializeVector(row.Embeddings) serializedEmbeddings := SerializeVector(row.Embeddings)
query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName) query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
_, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName) _, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName)
return err return err
} }
@@ -87,27 +84,22 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
rows, err := p.db.Query(querySQL) rows, err := p.db.Query(querySQL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
type SearchResult struct { type SearchResult struct {
vector models.VectorRow vector models.VectorRow
distance float32 distance float32
} }
var topResults []SearchResult var topResults []SearchResult
for rows.Next() { for rows.Next() {
var ( var (
embeddingsBlob []byte embeddingsBlob []byte
slug, rawText, fileName string slug, rawText, fileName string
) )
if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil { if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil {
continue continue
} }
@@ -152,7 +144,6 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
result.vector.Distance = result.distance result.vector.Distance = result.distance
results[i] = result.vector results[i] = result.vector
} }
return results, nil return results, nil
} }
@@ -161,18 +152,15 @@ func cosineSimilarity(a, b []float32) float32 {
if len(a) != len(b) { if len(a) != len(b) {
return 0.0 return 0.0
} }
var dotProduct, normA, normB float32 var dotProduct, normA, normB float32
for i := 0; i < len(a); i++ { for i := 0; i < len(a); i++ {
dotProduct += a[i] * b[i] dotProduct += a[i] * b[i]
normA += a[i] * a[i] normA += a[i] * a[i]
normB += b[i] * b[i] normB += b[i] * b[i]
} }
if normA == 0 || normB == 0 { if normA == 0 || normB == 0 {
return 0.0 return 0.0
} }
return dotProduct / (sqrt(normA) * sqrt(normB)) return dotProduct / (sqrt(normA) * sqrt(normB))
} }
@@ -229,13 +217,11 @@ func (p ProviderSQL) ListFiles() ([]string, error) {
} }
} }
} }
return allFiles, nil return allFiles, nil
} }
func (p ProviderSQL) RemoveEmbByFileName(filename string) error { func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
var errors []string var errors []string
tableNames := []string{ tableNames := []string{
"embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536", "embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536",
"embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120", "embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120",
@@ -246,10 +232,8 @@ func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
errors = append(errors, err.Error()) errors = append(errors, err.Error())
} }
} }
if len(errors) > 0 { if len(errors) > 0 {
return fmt.Errorf("errors occurred: %v", errors) return fmt.Errorf("errors occurred: %v", errors)
} }
return nil return nil
} }

View File

@@ -287,7 +287,6 @@ func makeRAGTable(fileList []string, loadedFiles []string) *tview.Flex {
}) })
} }
} }
rows := len(ragFiles) rows := len(ragFiles)
cols := 4 // File Name | Preview | Action | Delete cols := 4 // File Name | Preview | Action | Delete
fileTable := tview.NewTable(). fileTable := tview.NewTable().