Feat: migration on startup

This commit is contained in:
Grail Finder
2024-11-20 13:21:51 +03:00
parent 74669b58fe
commit 8ae4d075c4
5 changed files with 69 additions and 9 deletions

View File

@@ -19,3 +19,4 @@
### FIX: ### FIX:
- bot responding (or haninging) blocks everything; + - bot responding (or haninging) blocks everything; +
- programm requires history folder, but it is .gitignore; + - programm requires history folder, but it is .gitignore; +
- at first run chat table does not exist; run migrations sql on startup; +

3
bot.go
View File

@@ -276,10 +276,9 @@ func init() {
if err := os.MkdirAll(historyDir, os.ModePerm); err != nil { if err := os.MkdirAll(historyDir, os.ModePerm); err != nil {
panic(err) panic(err)
} }
store = storage.NewProviderSQL("test.db")
// defer file.Close()
logger = slog.New(slog.NewTextHandler(file, nil)) logger = slog.New(slog.NewTextHandler(file, nil))
logger.Info("test msg") logger.Info("test msg")
store = storage.NewProviderSQL("test.db", logger)
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md // https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
// load all chats in memory // load all chats in memory
loadHistoryChats() loadHistoryChats()

59
storage/migrate.go Normal file
View File

@@ -0,0 +1,59 @@
package storage
import (
"embed"
"fmt"
"io/fs"
"strings"
)
//go:embed migrations/*
var migrationsFS embed.FS
func (p *ProviderSQL) Migrate() {
// Get the embedded filesystem
migrationsDir, err := fs.Sub(migrationsFS, "migrations")
if err != nil {
p.logger.Error("Failed to get embedded migrations directory;", "error", err)
}
// List all .up.sql files
files, err := migrationsFS.ReadDir("migrations")
if err != nil {
p.logger.Error("Failed to read migrations directory;", "error", err)
}
// Execute each .up.sql file
for _, file := range files {
if strings.HasSuffix(file.Name(), ".up.sql") {
err := p.executeMigration(migrationsDir, file.Name())
if err != nil {
p.logger.Error("Failed to execute migration %s: %v", file.Name(), err)
}
}
}
p.logger.Info("All migrations executed successfully!")
}
func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) error {
// Open the migration file
migrationFile, err := migrationsDir.Open(fileName)
if err != nil {
return fmt.Errorf("failed to open migration file %s: %w", fileName, err)
}
defer migrationFile.Close()
// Read the migration file content
migrationContent, err := fs.ReadFile(migrationsDir, fileName)
if err != nil {
return fmt.Errorf("failed to read migration file %s: %w", fileName, err)
}
// Execute the migration content
return p.executeSQL(migrationContent)
}
func (p *ProviderSQL) executeSQL(sqlContent []byte) error {
// Connect to the database (example using a simple connection)
_, err := p.db.Exec(string(sqlContent))
if err != nil {
return fmt.Errorf("failed to execute SQL: %w", err)
}
return nil
}

View File

@@ -1,4 +1,4 @@
CREATE TABLE chat ( CREATE TABLE IF NOT EXISTS chat (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL, name TEXT NOT NULL,
msgs TEXT NOT NULL, -- Store messages as a comma-separated string msgs TEXT NOT NULL, -- Store messages as a comma-separated string

View File

@@ -2,7 +2,7 @@ package storage
import ( import (
"elefant/models" "elefant/models"
"fmt" "log/slog"
_ "github.com/glebarez/go-sqlite" _ "github.com/glebarez/go-sqlite"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@@ -17,7 +17,8 @@ type ChatHistory interface {
} }
type ProviderSQL struct { type ProviderSQL struct {
db *sqlx.DB db *sqlx.DB
logger *slog.Logger
} }
func (p ProviderSQL) ListChats() ([]models.Chat, error) { func (p ProviderSQL) ListChats() ([]models.Chat, error) {
@@ -60,13 +61,13 @@ func (p ProviderSQL) RemoveChat(id uint32) error {
return err return err
} }
func NewProviderSQL(dbPath string) ChatHistory { func NewProviderSQL(dbPath string, logger *slog.Logger) ChatHistory {
db, err := sqlx.Open("sqlite", dbPath) db, err := sqlx.Open("sqlite", dbPath)
if err != nil { if err != nil {
panic(err) panic(err)
} }
// get SQLite version // get SQLite version
res := db.QueryRow("select sqlite_version()") p := ProviderSQL{db: db, logger: logger}
fmt.Println(res) p.Migrate()
return ProviderSQL{db: db} return p
} }