264 lines
6.4 KiB
Go
264 lines
6.4 KiB
Go
package storage
|
|
|
|
import (
|
|
"elefant/models"
|
|
"fmt"
|
|
"log"
|
|
"log/slog"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces"
|
|
_ "github.com/glebarez/go-sqlite"
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/ncruces/go-sqlite3"
|
|
)
|
|
|
|
func TestMemories(t *testing.T) {
|
|
db, err := sqlx.Open("sqlite", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Failed to open SQLite in-memory database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
_, err = db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS memories (
|
|
agent TEXT NOT NULL,
|
|
topic TEXT NOT NULL,
|
|
mind TEXT NOT NULL,
|
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
PRIMARY KEY (agent, topic)
|
|
);`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create chat table: %v", err)
|
|
}
|
|
provider := ProviderSQL{
|
|
db: db,
|
|
logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)),
|
|
}
|
|
// Create a sample memory for testing
|
|
sampleMemory := models.Memory{
|
|
Agent: "testAgent",
|
|
Topic: "testTopic",
|
|
Mind: "testMind",
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
sampleMemoryRewrite := models.Memory{
|
|
Agent: "testAgent",
|
|
Topic: "testTopic",
|
|
Mind: "same topic, new mind",
|
|
}
|
|
cases := []struct {
|
|
memories []models.Memory
|
|
}{
|
|
{memories: []models.Memory{sampleMemory, sampleMemoryRewrite}},
|
|
}
|
|
for i, tc := range cases {
|
|
t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) {
|
|
// Recall topics: get no rows
|
|
topics, err := provider.RecallTopics(tc.memories[0].Agent)
|
|
if err != nil {
|
|
t.Fatalf("Failed to recall topics: %v", err)
|
|
}
|
|
if len(topics) != 0 {
|
|
t.Fatalf("Expected no topics, got: %v", topics)
|
|
}
|
|
// Memorise
|
|
_, err = provider.Memorise(&tc.memories[0])
|
|
if err != nil {
|
|
t.Fatalf("Failed to memorise: %v", err)
|
|
}
|
|
// Recall topics: has topics
|
|
topics, err = provider.RecallTopics(tc.memories[0].Agent)
|
|
if err != nil {
|
|
t.Fatalf("Failed to recall topics: %v", err)
|
|
}
|
|
if len(topics) == 0 {
|
|
t.Fatalf("Expected topics, got none")
|
|
}
|
|
// Recall
|
|
content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic)
|
|
if err != nil {
|
|
t.Fatalf("Failed to recall: %v", err)
|
|
}
|
|
if content != tc.memories[0].Mind {
|
|
t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content)
|
|
}
|
|
// rewrite mind of same agent-topic
|
|
newMem, err := provider.Memorise(&tc.memories[1])
|
|
if err != nil {
|
|
t.Fatalf("Failed to memorise: %v", err)
|
|
}
|
|
if newMem.Mind == tc.memories[0].Mind {
|
|
t.Fatalf("Failed to change mind: %v", newMem.Mind)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestChatHistory(t *testing.T) {
|
|
// Create an in-memory SQLite database
|
|
db, err := sqlx.Open("sqlite", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Failed to open SQLite in-memory database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
// Create the chat table
|
|
_, err = db.Exec(`
|
|
CREATE TABLE chats (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
name TEXT NOT NULL,
|
|
msgs TEXT NOT NULL,
|
|
agent TEXT NOT NULL,
|
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
);`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create chat table: %v", err)
|
|
}
|
|
// Initialize the ProviderSQL struct
|
|
provider := ProviderSQL{db: db}
|
|
// List chats (should be empty)
|
|
chats, err := provider.ListChats()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list chats: %v", err)
|
|
}
|
|
if len(chats) != 0 {
|
|
t.Errorf("Expected 0 chats, got %d", len(chats))
|
|
}
|
|
// Upsert a chat
|
|
chat := &models.Chat{
|
|
ID: 1,
|
|
Name: "Test Chat",
|
|
Msgs: "Hello World",
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
updatedChat, err := provider.UpsertChat(chat)
|
|
if err != nil {
|
|
t.Fatalf("Failed to upsert chat: %v", err)
|
|
}
|
|
if updatedChat == nil {
|
|
t.Errorf("Expected non-nil chat after upsert")
|
|
}
|
|
// Get chat by ID
|
|
fetchedChat, err := provider.GetChatByID(chat.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get chat by ID: %v", err)
|
|
}
|
|
if fetchedChat == nil {
|
|
t.Errorf("Expected non-nil chat after get")
|
|
}
|
|
if fetchedChat.Name != chat.Name {
|
|
t.Errorf("Expected chat name %s, got %s", chat.Name, fetchedChat.Name)
|
|
}
|
|
// List chats (should contain the upserted chat)
|
|
chats, err = provider.ListChats()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list chats: %v", err)
|
|
}
|
|
if len(chats) != 1 {
|
|
t.Errorf("Expected 1 chat, got %d", len(chats))
|
|
}
|
|
// Remove chat
|
|
err = provider.RemoveChat(chat.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to remove chat: %v", err)
|
|
}
|
|
// List chats (should be empty again)
|
|
chats, err = provider.ListChats()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list chats: %v", err)
|
|
}
|
|
if len(chats) != 0 {
|
|
t.Errorf("Expected 0 chats, got %d", len(chats))
|
|
}
|
|
}
|
|
|
|
func TestVecTable(t *testing.T) {
|
|
// healthcheck
|
|
db, err := sqlite3.Open(":memory:")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
stmt.Step()
|
|
log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1))
|
|
stmt.Close()
|
|
// migration
|
|
err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// data prep and insert
|
|
items := map[int][]float32{
|
|
1: {0.1, 0.1, 0.1, 0.1},
|
|
2: {0.2, 0.2, 0.2, 0.2},
|
|
3: {0.3, 0.3, 0.3, 0.3},
|
|
4: {0.4, 0.4, 0.4, 0.4},
|
|
5: {0.5, 0.5, 0.5, 0.5},
|
|
}
|
|
q := []float32{0.28, 0.3, 0.3, 0.3}
|
|
stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for id, values := range items {
|
|
v, err := sqlite_vec.SerializeFloat32(values)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
stmt.BindInt(1, id)
|
|
stmt.BindBlob(2, v)
|
|
stmt.BindText(3, "some_chat")
|
|
err = stmt.Exec()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
stmt.Reset()
|
|
}
|
|
stmt.Close()
|
|
// select | vec search
|
|
stmt, _, err = db.Prepare(`
|
|
SELECT
|
|
rowid,
|
|
distance,
|
|
embedding
|
|
FROM vec_items
|
|
WHERE embedding MATCH ?
|
|
ORDER BY distance
|
|
LIMIT 3
|
|
`)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
query, err := sqlite_vec.SerializeFloat32(q)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
stmt.BindBlob(1, query)
|
|
for stmt.Step() {
|
|
rowid := stmt.ColumnInt64(0)
|
|
distance := stmt.ColumnFloat(1)
|
|
emb := stmt.ColumnRawText(2)
|
|
floats := decodeUnsafe(emb)
|
|
log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats)
|
|
}
|
|
if err := stmt.Err(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = stmt.Close()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = db.Close()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|