grab/db/migrations.go

166 lines
5.3 KiB
Go
Raw Normal View History

// Package grabdb provides a service-agnostic way to run database migrations.
// Supply a database connection and a filesystem full of migration scripts,
// then call PrepareDatabase and RunMigrations (both idempotent) and you're off
// to the races.
//
// Currently assumes MySQL, but will be expanded with more DB backends Soon™.
package grabdb
import (
"context"
"database/sql"
"fmt"
"io/fs"
"strconv"
"strings"
"time"
)
// Migrations bundles what is needed to apply numbered SQL migration scripts to
// a database and to record which of them have been applied.
type Migrations struct {
// Connection is the database the migrations are executed against.
Connection *sql.DB
// VersionTable names the table that records applied migrations. It is
// concatenated directly into SQL statements, so it must be a trusted, plain
// table identifier.
VersionTable string
// MigrationsFS holds the migration scripts, named <id>-<name>.sql.
MigrationsFS fs.ReadDirFS
}
// migration is a single migration script loaded from the migrations filesystem.
type migration struct {
// id is the integer prefix of the file name (the text before the first '-').
id int
// name is the rest of the file name after the first '-', minus any ".sql" suffix.
name string
// content is the raw SQL text of the file, executed as one statement batch.
content []byte
}
// NewMigrations returns a Migrations value ready to prepare its version table
// and apply migrations.
//
// connection is the database to migrate and must be non-nil.
//
// dbType is currently unused; the package only targets MySQL.
//
// tablename names the table that records applied migrations. Because it is
// spliced directly into SQL statements, it must be a plain identifier of at
// most 64 ASCII letters, digits, '_' or '$' characters. Anything else is
// rejected.
//
// fs must hold the migration scripts as regular files at its root, named
// <id>-<name>.sql (for example 01-create-users.sql, 02-add-index.sql), where
// <id> is a decimal integer. Migrations are applied in increasing id order and
// only once each, so every id must be unique.
func NewMigrations(connection *sql.DB, dbType string, tablename string, fs fs.ReadDirFS) (*Migrations, error) {
	if connection == nil {
		return nil, fmt.Errorf("db connection must be provided")
	}
	if tablename == "" {
		return nil, fmt.Errorf("tablename must be provided")
	}
	// The name is concatenated into SQL later, so refuse anything that could
	// change the statement's structure or that the existence check in
	// PrepareDatabase could never match (e.g. "db.table" or quoted names).
	if !isPlainIdentifier(tablename) {
		return nil, fmt.Errorf("tablename %q must be at most 64 ASCII letters, digits, '_' or '$'", tablename)
	}
	if fs == nil {
		return nil, fmt.Errorf("filesystem must be provided")
	}
	return &Migrations{
		Connection:   connection,
		VersionTable: tablename,
		MigrationsFS: fs,
	}, nil
}

// isPlainIdentifier reports whether s is non-empty, at most 64 bytes long, and
// made only of ASCII letters, digits, '_' and '$'.
func isPlainIdentifier(s string) bool {
	if s == "" || len(s) > 64 {
		return false
	}
	for _, r := range s {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '_', r == '$':
		default:
			return false
		}
	}
	return true
}
// PrepareDatabase creates the version table in the current database (as
// reported by DATABASE()) if it is not already there. Calling it again once the
// table exists does nothing.
//
// The table has columns id (primary key), name, and datetime (when the
// migration was applied).
func (m *Migrations) PrepareDatabase(ctx context.Context) error {
	if m.Connection == nil || m.VersionTable == "" {
		return fmt.Errorf("uninitialized migration struct")
	}

	// The table name is bound as a parameter here; it is only a string value in
	// this query, unlike in the CREATE statement below.
	const tablecheck = `SELECT count(*) AS count
FROM information_schema.TABLES
WHERE TABLE_NAME = ?
AND TABLE_SCHEMA = DATABASE()`

	var versionTableExists int
	if err := m.Connection.QueryRowContext(ctx, tablecheck, m.VersionTable).Scan(&versionTableExists); err != nil {
		// Previously this error was dropped and CREATE TABLE was attempted blindly.
		return fmt.Errorf("checking for version table %q: %w", m.VersionTable, err)
	}
	if versionTableExists != 0 {
		return nil
	}

	// Identifiers cannot be bound, so the name is concatenated; it must be a
	// trusted identifier. IF NOT EXISTS keeps this safe if another process
	// creates the table between the check above and this statement.
	// The datetime column is DATETIME (not DATE) so the time of day passed in
	// when a migration is recorded is kept.
	tableschema := `CREATE TABLE IF NOT EXISTS ` + m.VersionTable + ` (
id INT NOT NULL,
name VARCHAR(100) NOT NULL,
datetime DATETIME,
PRIMARY KEY (id))`
	if _, err := m.Connection.ExecContext(ctx, tableschema); err != nil {
		return fmt.Errorf("creating version table %q: %w", m.VersionTable, err)
	}
	return nil
}
// GetLatestMigration reports the highest migration id recorded in the version
// table. An empty table yields 0. A query or scan failure is returned
// unwrapped.
func (m *Migrations) GetLatestMigration(ctx context.Context) (int, error) {
	if m.Connection == nil || m.VersionTable == "" {
		return 0, fmt.Errorf("uninitialized migration struct")
	}

	query := "SELECT COALESCE(MAX(id), 0) FROM " + m.VersionTable
	row := m.Connection.QueryRowContext(ctx, query)

	var highest int
	if err := row.Scan(&highest); err != nil {
		return highest, err
	}
	return highest, nil
}
// RunMigrations applies, inside a single transaction, every migration whose id
// follows the most recently recorded one. It moves up one id at a time and
// stops at the first id with no matching file. Each applied migration is
// recorded in the version table.
//
// On a database with no recorded migrations, a migration with id 0 (e.g.
// 00-init.sql) is applied first if present; otherwise application starts at 1.
//
// It returns:
//   - the id of the last migration applied or already recorded (0 when none),
//   - the number of migrations applied by this call,
//   - any error.
//
// Calling it again when everything is up to date does nothing.
//
// NOTE(review): each script is sent as a single Exec call. Scripts containing
// several statements likely need multi-statement support from the driver;
// confirm against the configured driver/DSN.
// NOTE(review): MySQL implicitly commits most DDL, so rolling back after a
// failure may not undo schema changes already made; confirm before relying on it.
func (m *Migrations) RunMigrations(ctx context.Context) (int, int, error) {
	if m.Connection == nil || m.VersionTable == "" || m.MigrationsFS == nil {
		return 0, 0, fmt.Errorf("uninitialized migration struct")
	}

	migrations, err := readMigrationFiles(m.MigrationsFS)
	if err != nil {
		return 0, 0, err
	}

	// MAX(id) is NULL on an empty table. Scanning it directly distinguishes
	// "nothing recorded" from "migration 0 recorded"; COALESCE would report 0
	// for both and migration 0 would never run.
	var maxID sql.NullInt64
	if err := m.Connection.QueryRowContext(ctx, "SELECT MAX(id) FROM "+m.VersionTable).Scan(&maxID); err != nil {
		return 0, 0, fmt.Errorf("unable to determine most recent migration: %w", err)
	}
	latestMigrationRan := int(maxID.Int64)
	next := latestMigrationRan + 1
	if _, hasZero := migrations[0]; hasZero && !maxID.Valid {
		next = 0
	}

	// Exit without opening a transaction if there is nothing to do.
	if _, ok := migrations[next]; !ok {
		return latestMigrationRan, 0, nil
	}

	tx, err := m.Connection.BeginTx(ctx, nil)
	if err != nil {
		return latestMigrationRan, 0, fmt.Errorf("unable to open transaction: %w", err)
	}
	// abort rolls back and returns cause. A rollback failure is appended to
	// cause rather than replacing it, so the original error is not lost.
	abort := func(cause error) error {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("%w (rollback also failed: %v)", cause, rbErr)
		}
		return cause
	}

	migrationsRun := 0
	for {
		mig, ok := migrations[next]
		if !ok {
			break
		}
		if _, err := tx.ExecContext(ctx, string(mig.content)); err != nil {
			return latestMigrationRan, migrationsRun, abort(fmt.Errorf("error executing migration %d: %w", mig.id, err))
		}
		if _, err := tx.ExecContext(ctx, "INSERT INTO "+m.VersionTable+" (id, name, datetime) VALUES (?, ?, ?)", mig.id, mig.name, time.Now()); err != nil {
			return latestMigrationRan, migrationsRun, abort(fmt.Errorf("failure recording migration in '%s': %w", m.VersionTable, err))
		}
		latestMigrationRan = mig.id
		migrationsRun++
		next++
	}

	if err := tx.Commit(); err != nil {
		return latestMigrationRan, migrationsRun, fmt.Errorf("failure committing transaction: %w", err)
	}
	return latestMigrationRan, migrationsRun, nil
}

// readMigrationFiles loads every regular file at the root of fsys as a
// migration, keyed by the integer id parsed from its file name.
//
// It returns an error if the directory or a file cannot be read, if a name
// cannot be parsed, or if two files share an id. Paths are relative to "."
// because io/fs rejects rooted paths such as "/".
func readMigrationFiles(fsys fs.ReadDirFS) (map[int]migration, error) {
	entries, err := fsys.ReadDir(".")
	if err != nil {
		return nil, fmt.Errorf("could not load migrations from fs directory: %w", err)
	}
	migrations := make(map[int]migration, len(entries))
	for _, entry := range entries {
		if !entry.Type().IsRegular() {
			continue
		}
		id, name, err := parseMigrationFileName(entry.Name())
		if err != nil {
			return nil, fmt.Errorf("could not parse migration from fs directory: %w", err)
		}
		if _, dup := migrations[id]; dup {
			return nil, fmt.Errorf("duplicate migration id %d (file %q)", id, entry.Name())
		}
		content, err := fs.ReadFile(fsys, entry.Name())
		if err != nil {
			return nil, fmt.Errorf("could not load migration from fs: %w", err)
		}
		migrations[id] = migration{id: id, name: name, content: content}
	}
	return migrations, nil
}
// parseMigrationFileName splits a file name of the form <id>-<name>.sql into
// its integer id and its name.
//
// The name is everything after the first '-', with a trailing ".sql" removed
// if present. It returns an error that names the file if there is no '-' or
// the id is not an integer. Previously a name without '-' caused an
// index-out-of-range panic.
func parseMigrationFileName(filename string) (int, string, error) {
	parts := strings.SplitN(filename, "-", 2)
	if len(parts) != 2 {
		return 0, "", fmt.Errorf("migration file name %q does not match <id>-<name>.sql", filename)
	}
	id, err := strconv.Atoi(parts[0])
	if err != nil {
		return 0, "", fmt.Errorf("migration file name %q has a non-integer id: %w", filename, err)
	}
	return id, strings.TrimSuffix(parts[1], ".sql"), nil
}