107 lines
2.1 KiB
Go
107 lines
2.1 KiB
Go
package repos
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
type AllRepos interface {
|
|
RoomsRepo
|
|
ActionsRepo
|
|
PlayersRepo
|
|
SessionsRepo
|
|
WordCardsRepo
|
|
InitTx(ctx context.Context) (context.Context, *sqlx.Tx, error)
|
|
}
|
|
|
|
type RepoProvider struct {
|
|
DB *sqlx.DB
|
|
mu sync.RWMutex
|
|
pathToDB string
|
|
}
|
|
|
|
func NewRepoProvider(pathToDB string) *RepoProvider {
|
|
db, err := sqlx.Connect("sqlite3", pathToDB)
|
|
if err != nil {
|
|
slog.Error("Unable to connect to database", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
slog.Info("Successfully connected to database")
|
|
rp := &RepoProvider{
|
|
DB: db,
|
|
pathToDB: pathToDB,
|
|
}
|
|
|
|
go rp.pingLoop()
|
|
|
|
return rp
|
|
}
|
|
|
|
func (rp *RepoProvider) pingLoop() {
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
if err := rp.pingDB(); err != nil {
|
|
slog.Error("Database ping failed, attempting to reconnect...", "error", err)
|
|
rp.reconnect()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (rp *RepoProvider) pingDB() error {
|
|
rp.mu.RLock()
|
|
defer rp.mu.RUnlock()
|
|
if rp.DB == nil {
|
|
return os.ErrClosed
|
|
}
|
|
return rp.DB.Ping()
|
|
}
|
|
|
|
func (rp *RepoProvider) reconnect() {
|
|
rp.mu.Lock()
|
|
defer rp.mu.Unlock()
|
|
|
|
// Double-check if connection is still down
|
|
if rp.DB != nil {
|
|
if err := rp.DB.Ping(); err == nil {
|
|
slog.Info("Database connection already re-established.")
|
|
return
|
|
}
|
|
// if ping fails, we continue to reconnect
|
|
rp.DB.Close() // close old connection
|
|
}
|
|
|
|
slog.Info("Reconnecting to database...")
|
|
db, err := sqlx.Connect("sqlite3", rp.pathToDB)
|
|
if err != nil {
|
|
slog.Error("Failed to reconnect to database", "error", err)
|
|
rp.DB = nil // make sure DB is nil if connection failed
|
|
return
|
|
}
|
|
|
|
rp.DB = db
|
|
slog.Info("Successfully reconnected to database")
|
|
}
|
|
|
|
func getDB(ctx context.Context, db *sqlx.DB) sqlx.ExtContext {
|
|
if tx, ok := ctx.Value("tx").(*sqlx.Tx); ok {
|
|
return tx
|
|
}
|
|
return db
|
|
}
|
|
|
|
func (p *RepoProvider) InitTx(ctx context.Context) (context.Context, *sqlx.Tx, error) {
|
|
tx, err := p.DB.BeginTxx(ctx, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return context.WithValue(ctx, "tx", tx), tx, nil
|
|
}
|