261 lines
6.6 KiB
Go
261 lines
6.6 KiB
Go
package storage
|
|
|
|
import (
|
|
"gf-lt/models"
|
|
"fmt"
|
|
"log/slog"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
_ "github.com/glebarez/go-sqlite"
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
func TestMemories(t *testing.T) {
|
|
db, err := sqlx.Open("sqlite", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Failed to open SQLite in-memory database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
_, err = db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS memories (
|
|
agent TEXT NOT NULL,
|
|
topic TEXT NOT NULL,
|
|
mind TEXT NOT NULL,
|
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
PRIMARY KEY (agent, topic)
|
|
);`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create chat table: %v", err)
|
|
}
|
|
provider := ProviderSQL{
|
|
db: db,
|
|
logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)),
|
|
}
|
|
// Create a sample memory for testing
|
|
sampleMemory := models.Memory{
|
|
Agent: "testAgent",
|
|
Topic: "testTopic",
|
|
Mind: "testMind",
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
sampleMemoryRewrite := models.Memory{
|
|
Agent: "testAgent",
|
|
Topic: "testTopic",
|
|
Mind: "same topic, new mind",
|
|
}
|
|
cases := []struct {
|
|
memories []models.Memory
|
|
}{
|
|
{memories: []models.Memory{sampleMemory, sampleMemoryRewrite}},
|
|
}
|
|
for i, tc := range cases {
|
|
t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) {
|
|
// Recall topics: get no rows
|
|
topics, err := provider.RecallTopics(tc.memories[0].Agent)
|
|
if err != nil {
|
|
t.Fatalf("Failed to recall topics: %v", err)
|
|
}
|
|
if len(topics) != 0 {
|
|
t.Fatalf("Expected no topics, got: %v", topics)
|
|
}
|
|
// Memorise
|
|
_, err = provider.Memorise(&tc.memories[0])
|
|
if err != nil {
|
|
t.Fatalf("Failed to memorise: %v", err)
|
|
}
|
|
// Recall topics: has topics
|
|
topics, err = provider.RecallTopics(tc.memories[0].Agent)
|
|
if err != nil {
|
|
t.Fatalf("Failed to recall topics: %v", err)
|
|
}
|
|
if len(topics) == 0 {
|
|
t.Fatalf("Expected topics, got none")
|
|
}
|
|
// Recall
|
|
content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic)
|
|
if err != nil {
|
|
t.Fatalf("Failed to recall: %v", err)
|
|
}
|
|
if content != tc.memories[0].Mind {
|
|
t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content)
|
|
}
|
|
// rewrite mind of same agent-topic
|
|
newMem, err := provider.Memorise(&tc.memories[1])
|
|
if err != nil {
|
|
t.Fatalf("Failed to memorise: %v", err)
|
|
}
|
|
if newMem.Mind == tc.memories[0].Mind {
|
|
t.Fatalf("Failed to change mind: %v", newMem.Mind)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestChatHistory(t *testing.T) {
|
|
// Create an in-memory SQLite database
|
|
db, err := sqlx.Open("sqlite", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Failed to open SQLite in-memory database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
// Create the chat table
|
|
_, err = db.Exec(`
|
|
CREATE TABLE chats (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
name TEXT NOT NULL,
|
|
msgs TEXT NOT NULL,
|
|
agent TEXT NOT NULL,
|
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
);`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create chat table: %v", err)
|
|
}
|
|
// Initialize the ProviderSQL struct
|
|
provider := ProviderSQL{db: db}
|
|
// List chats (should be empty)
|
|
chats, err := provider.ListChats()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list chats: %v", err)
|
|
}
|
|
if len(chats) != 0 {
|
|
t.Errorf("Expected 0 chats, got %d", len(chats))
|
|
}
|
|
// Upsert a chat
|
|
chat := &models.Chat{
|
|
ID: 1,
|
|
Name: "Test Chat",
|
|
Msgs: "Hello World",
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
updatedChat, err := provider.UpsertChat(chat)
|
|
if err != nil {
|
|
t.Fatalf("Failed to upsert chat: %v", err)
|
|
}
|
|
if updatedChat == nil {
|
|
t.Errorf("Expected non-nil chat after upsert")
|
|
}
|
|
// Get chat by ID
|
|
fetchedChat, err := provider.GetChatByID(chat.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get chat by ID: %v", err)
|
|
}
|
|
if fetchedChat == nil {
|
|
t.Errorf("Expected non-nil chat after get")
|
|
}
|
|
if fetchedChat.Name != chat.Name {
|
|
t.Errorf("Expected chat name %s, got %s", chat.Name, fetchedChat.Name)
|
|
}
|
|
// List chats (should contain the upserted chat)
|
|
chats, err = provider.ListChats()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list chats: %v", err)
|
|
}
|
|
if len(chats) != 1 {
|
|
t.Errorf("Expected 1 chat, got %d", len(chats))
|
|
}
|
|
// Remove chat
|
|
err = provider.RemoveChat(chat.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to remove chat: %v", err)
|
|
}
|
|
// List chats (should be empty again)
|
|
chats, err = provider.ListChats()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list chats: %v", err)
|
|
}
|
|
if len(chats) != 0 {
|
|
t.Errorf("Expected 0 chats, got %d", len(chats))
|
|
}
|
|
}
|
|
|
|
// func TestVecTable(t *testing.T) {
|
|
// // healthcheck
|
|
// db, err := sqlite3.Open(":memory:")
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`)
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// stmt.Step()
|
|
// log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1))
|
|
// stmt.Close()
|
|
// // migration
|
|
// err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)")
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// // data prep and insert
|
|
// items := map[int][]float32{
|
|
// 1: {0.1, 0.1, 0.1, 0.1},
|
|
// 2: {0.2, 0.2, 0.2, 0.2},
|
|
// 3: {0.3, 0.3, 0.3, 0.3},
|
|
// 4: {0.4, 0.4, 0.4, 0.4},
|
|
// 5: {0.5, 0.5, 0.5, 0.5},
|
|
// }
|
|
// q := []float32{0.4, 0.3, 0.3, 0.3}
|
|
// stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)")
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// for id, values := range items {
|
|
// v, err := sqlite_vec.SerializeFloat32(values)
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// stmt.BindInt(1, id)
|
|
// stmt.BindBlob(2, v)
|
|
// stmt.BindText(3, "some_chat")
|
|
// err = stmt.Exec()
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// stmt.Reset()
|
|
// }
|
|
// stmt.Close()
|
|
// // select | vec search
|
|
// stmt, _, err = db.Prepare(`
|
|
// SELECT
|
|
// rowid,
|
|
// distance,
|
|
// embedding
|
|
// FROM vec_items
|
|
// WHERE embedding MATCH ?
|
|
// ORDER BY distance
|
|
// LIMIT 3
|
|
// `)
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// query, err := sqlite_vec.SerializeFloat32(q)
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// stmt.BindBlob(1, query)
|
|
// for stmt.Step() {
|
|
// rowid := stmt.ColumnInt64(0)
|
|
// distance := stmt.ColumnFloat(1)
|
|
// emb := stmt.ColumnRawText(2)
|
|
// floats := decodeUnsafe(emb)
|
|
// log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats)
|
|
// }
|
|
// if err := stmt.Err(); err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// err = stmt.Close()
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// err = db.Close()
|
|
// if err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// }
|