Chore: linter complaints

This commit is contained in:
Grail Finder
2026-02-25 20:06:56 +03:00
parent 4f07994bdc
commit 888c9fec65
12 changed files with 11 additions and 123 deletions

View File

@@ -131,7 +131,6 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
}
embeddings[data.Index] = data.Embedding
}
return embeddings, nil
}

View File

@@ -95,9 +95,7 @@ func extractTextFromEpub(fpath string) (string, error) {
return "", fmt.Errorf("failed to open epub: %w", err)
}
defer r.Close()
var sb strings.Builder
for _, f := range r.File {
ext := strings.ToLower(path.Ext(f.Name))
if ext != ".xhtml" && ext != ".html" && ext != ".htm" && ext != ".xml" {
@@ -129,7 +127,6 @@ func extractTextFromEpub(fpath string) (string, error) {
sb.WriteString(stripHTML(string(buf)))
}
}
if sb.Len() == 0 {
return "", errors.New("no content extracted from epub")
}

View File

@@ -36,7 +36,6 @@ type RAG struct {
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
// Initialize with API embedder by default, could be configurable later
embedder := NewAPIEmbedder(l, cfg)
rag := &RAG{
logger: l,
store: s,
@@ -205,29 +204,22 @@ var (
func (r *RAG) RefineQuery(query string) string {
original := query
query = strings.TrimSpace(query)
if len(query) == 0 {
return original
}
if len(query) <= 3 {
return original
}
query = strings.ToLower(query)
for _, stopWord := range stopWords {
wordPattern := `\b` + stopWord + `\b`
re := regexp.MustCompile(wordPattern)
query = re.ReplaceAllString(query, "")
}
query = strings.TrimSpace(query)
if len(query) < 5 {
return original
}
if queryRefinementPattern.MatchString(original) {
cleaned := queryRefinementPattern.ReplaceAllString(original, "")
cleaned = strings.TrimSpace(cleaned)
@@ -235,23 +227,18 @@ func (r *RAG) RefineQuery(query string) string {
return cleaned
}
}
query = r.extractImportantPhrases(query)
if len(query) < 5 {
return original
}
return query
}
func (r *RAG) extractImportantPhrases(query string) string {
words := strings.Fields(query)
var important []string
for _, word := range words {
word = strings.Trim(word, ".,!?;:'\"()[]{}")
isImportant := false
for _, kw := range importantKeywords {
if strings.Contains(strings.ToLower(word), kw) {
@@ -259,45 +246,37 @@ func (r *RAG) extractImportantPhrases(query string) string {
break
}
}
if isImportant || len(word) > 3 {
important = append(important, word)
}
}
if len(important) == 0 {
return query
}
return strings.Join(important, " ")
}
func (r *RAG) GenerateQueryVariations(query string) []string {
variations := []string{query}
if len(query) < 5 {
return variations
}
parts := strings.Fields(query)
if len(parts) == 0 {
return variations
}
if len(parts) >= 2 {
trimmed := strings.Join(parts[:len(parts)-1], " ")
if len(trimmed) >= 5 {
variations = append(variations, trimmed)
}
}
if len(parts) >= 2 {
trimmed := strings.Join(parts[1:], " ")
if len(trimmed) >= 5 {
variations = append(variations, trimmed)
}
}
if !strings.HasSuffix(query, " explanation") {
variations = append(variations, query+" explanation")
}
@@ -310,7 +289,6 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
if !strings.HasSuffix(query, " summary") {
variations = append(variations, query+" summary")
}
return variations
}
@@ -319,21 +297,16 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
row models.VectorRow
distance float32
}
scored := make([]scoredResult, 0, len(results))
for i := range results {
row := results[i]
score := float32(0)
rawTextLower := strings.ToLower(row.RawText)
queryLower := strings.ToLower(query)
if strings.Contains(rawTextLower, queryLower) {
score += 10
}
queryWords := strings.Fields(queryLower)
matchCount := 0
for _, word := range queryWords {
@@ -344,34 +317,26 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
if len(queryWords) > 0 {
score += float32(matchCount) / float32(len(queryWords)) * 5
}
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
score += 3
}
distance := row.Distance - score/100
scored = append(scored, scoredResult{row: row, distance: distance})
}
sort.Slice(scored, func(i, j int) bool {
return scored[i].distance < scored[j].distance
})
unique := make([]models.VectorRow, 0)
seen := make(map[string]bool)
for i := range scored {
if !seen[scored[i].row.Slug] {
seen[scored[i].row.Slug] = true
unique = append(unique, scored[i].row)
}
}
if len(unique) > 10 {
unique = unique[:10]
}
return unique
}
@@ -379,58 +344,47 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
if len(results) == 0 {
return "No relevant information found in the vector database.", nil
}
var contextBuilder strings.Builder
contextBuilder.WriteString("User Query: ")
contextBuilder.WriteString(query)
contextBuilder.WriteString("\n\nRetrieved Context:\n")
for i, row := range results {
contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName))
fmt.Fprintf(&contextBuilder, "[Source %d: %s]\n", i+1, row.FileName)
contextBuilder.WriteString(row.RawText)
contextBuilder.WriteString("\n\n")
}
contextBuilder.WriteString("Instructions: ")
contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
contextBuilder.WriteString("Extract only the most relevant information. ")
contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
contextBuilder.WriteString("Cite sources by filename when relevant. ")
contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
synthesisPrompt := contextBuilder.String()
emb, err := r.LineToVector(synthesisPrompt)
if err != nil {
r.logger.Error("failed to embed synthesis prompt", "error", err)
return "", err
}
embResp := &models.EmbeddingResp{
Embedding: emb,
Index: 0,
}
topResults, err := r.SearchEmb(embResp)
if err != nil {
r.logger.Error("failed to search for synthesis context", "error", err)
return "", err
}
if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
return topResults[0].RawText, nil
}
var finalAnswer strings.Builder
finalAnswer.WriteString("Based on the retrieved context:\n\n")
for i, row := range results {
if i >= 5 {
break
}
finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200)))
fmt.Fprintf(&finalAnswer, "- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))
}
return finalAnswer.String(), nil
}
@@ -444,10 +398,8 @@ func truncateString(s string, maxLen int) string {
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
refined := r.RefineQuery(query)
variations := r.GenerateQueryVariations(refined)
allResults := make([]models.VectorRow, 0)
seen := make(map[string]bool)
for _, q := range variations {
emb, err := r.LineToVector(q)
if err != nil {
@@ -473,13 +425,10 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
}
}
}
reranked := r.RerankResults(allResults, query)
if len(reranked) > limit {
reranked = reranked[:limit]
}
return reranked, nil
}

View File

@@ -28,7 +28,6 @@ func NewVectorStorage(logger *slog.Logger, store storage.FullRepo) *VectorStorag
}
}
// SerializeVector converts []float32 to binary blob
func SerializeVector(vec []float32) []byte {
buf := make([]byte, len(vec)*4) // 4 bytes per float32
@@ -66,17 +65,14 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
// Serialize the embeddings to binary
serializedEmbeddings := SerializeVector(row.Embeddings)
query := fmt.Sprintf(
"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
tableName,
)
if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
return err
}
return nil
}
@@ -86,20 +82,18 @@ func (vs *VectorStorage) getTableName(emb []float32) (string, error) {
// Check if we support this embedding size
supportedSizes := map[int]bool{
384: true,
768: true,
1024: true,
1536: true,
2048: true,
3072: true,
4096: true,
5120: true,
384: true,
768: true,
1024: true,
1536: true,
2048: true,
3072: true,
4096: true,
5120: true,
}
if supportedSizes[size] {
return fmt.Sprintf("embeddings_%d", size), nil
}
return "", fmt.Errorf("no table for embedding size of %d", size)
}
@@ -126,9 +120,7 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
vector models.VectorRow
distance float32
}
var topResults []SearchResult
// Process vectors one by one to avoid loading everything into memory
for rows.Next() {
var (
@@ -176,14 +168,12 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err
result.vector.Distance = result.distance
results = append(results, result.vector)
}
return results, nil
}
// ListFiles returns a list of all loaded files
func (vs *VectorStorage) ListFiles() ([]string, error) {
fileLists := make([][]string, 0)
// Query all supported tables and combine results
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
for _, size := range embeddingSizes {
@@ -219,14 +209,12 @@ func (vs *VectorStorage) ListFiles() ([]string, error) {
}
}
}
return allFiles, nil
}
// RemoveEmbByFileName removes all embeddings associated with a specific filename
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
var errors []string
embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120}
for _, size := range embeddingSizes {
table := fmt.Sprintf("embeddings_%d", size)
@@ -235,11 +223,9 @@ func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
errors = append(errors, err.Error())
}
}
if len(errors) > 0 {
return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; "))
}
return nil
}
@@ -248,18 +234,15 @@ func cosineSimilarity(a, b []float32) float32 {
if len(a) != len(b) {
return 0.0
}
var dotProduct, normA, normB float32
for i := 0; i < len(a); i++ {
dotProduct += a[i] * b[i]
normA += a[i] * a[i]
normB += b[i] * b[i]
}
if normA == 0 || normB == 0 {
return 0.0
}
return dotProduct / (sqrt(normA) * sqrt(normB))
}
@@ -275,4 +258,3 @@ func sqrt(f float32) float32 {
}
return guess
}