Enha: rewrite (upsert) memory

This commit is contained in:
Grail Finder
2025-02-09 18:21:34 +03:00
parent c857661393
commit 5468053908
4 changed files with 39 additions and 14 deletions

View File

@@ -9,7 +9,13 @@ type Memories interface {
}
func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) {
query := "INSERT INTO memories (agent, topic, mind) VALUES (:agent, :topic, :mind) RETURNING *;"
query := `
INSERT INTO memories (agent, topic, mind)
VALUES (:agent, :topic, :mind)
ON CONFLICT (agent, topic) DO UPDATE
SET mind = excluded.mind,
updated_at = CURRENT_TIMESTAMP
RETURNING *;`
stmt, err := p.db.PrepareNamed(query)
if err != nil {
p.logger.Error("failed to prepare stmt", "query", query, "error", err)
@@ -19,7 +25,7 @@ func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) {
var memory models.Memory
err = stmt.Get(&memory, m)
if err != nil {
p.logger.Error("failed to insert memory", "query", query, "error", err)
p.logger.Error("failed to upsert memory", "query", query, "error", err)
return nil, err
}
return &memory, nil

View File

@@ -38,22 +38,27 @@ CREATE TABLE IF NOT EXISTS memories (
logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)),
}
// Create a sample memory for testing
sampleMemory := &models.Memory{
sampleMemory := models.Memory{
Agent: "testAgent",
Topic: "testTopic",
Mind: "testMind",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
sampleMemoryRewrite := models.Memory{
Agent: "testAgent",
Topic: "testTopic",
Mind: "same topic, new mind",
}
cases := []struct {
memory *models.Memory
memories []models.Memory
}{
{memory: sampleMemory},
{memories: []models.Memory{sampleMemory, sampleMemoryRewrite}},
}
for i, tc := range cases {
t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) {
// Recall topics: get no rows
topics, err := provider.RecallTopics(tc.memory.Agent)
topics, err := provider.RecallTopics(tc.memories[0].Agent)
if err != nil {
t.Fatalf("Failed to recall topics: %v", err)
}
@@ -61,12 +66,12 @@ CREATE TABLE IF NOT EXISTS memories (
t.Fatalf("Expected no topics, got: %v", topics)
}
// Memorise
_, err = provider.Memorise(tc.memory)
_, err = provider.Memorise(&tc.memories[0])
if err != nil {
t.Fatalf("Failed to memorise: %v", err)
}
// Recall topics: has topics
topics, err = provider.RecallTopics(tc.memory.Agent)
topics, err = provider.RecallTopics(tc.memories[0].Agent)
if err != nil {
t.Fatalf("Failed to recall topics: %v", err)
}
@@ -74,12 +79,20 @@ CREATE TABLE IF NOT EXISTS memories (
t.Fatalf("Expected topics, got none")
}
// Recall
content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic)
content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic)
if err != nil {
t.Fatalf("Failed to recall: %v", err)
}
if content != tc.memory.Mind {
t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content)
if content != tc.memories[0].Mind {
t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content)
}
// rewrite mind of same agent-topic
newMem, err := provider.Memorise(&tc.memories[1])
if err != nil {
t.Fatalf("Failed to memorise: %v", err)
}
if newMem.Mind == tc.memories[0].Mind {
t.Fatalf("Failed to change mind: %v", newMem.Mind)
}
})
}