
166 lines
5.3 KiB
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package grabdb provides a database-agnostic way to run migrations for a
// service. Provide a database connection and an fs.ReadDirFS full of
// migration scripts, then call PrepareDatabase() and RunMigrations() (both
// idempotent) and you're off to the races.
// Currently assumes MySQL, but will be expanded with more DB backends Soon™.
package grabdb
import (
	"context"
	"database/sql"
	"fmt"
	"io/fs"
	"strconv"
	"strings"
	"time"
)
// Migrations applies numbered SQL scripts to a database and records each
// applied script in a version table. NewMigrations builds one with its
// required fields checked.
type Migrations struct {
	// Connection is the database the migration scripts run against.
	Connection *sql.DB
	// VersionTable names the table recording applied migrations. It is
	// concatenated into SQL statements, so it must be a trusted identifier.
	VersionTable string
	// MigrationsFS holds the migration scripts in its root directory.
	MigrationsFS fs.ReadDirFS
}
// migration is one SQL script loaded from the migrations filesystem.
type migration struct {
	id      int    // numeric file-name prefix; determines apply order
	name    string // file name after the "<id>-" prefix, minus ".sql"
	content []byte // script text, passed to a single ExecContext call
}
// NewMigrations returns a new Migration struct that provides db migration functions.
// fs must be a filesystem that has a folder of .sql files in the root,
// named in the following pattern: 00-name.sql, 01-name.sql, 02-name.sql...
// Migrations will be applied in numerical sort order. Migrations will only be applied once.
// Behavior of duplicated integers is undefined, but probably unpleasant.
func NewMigrations(connection *sql.DB, dbType string, tablename string, fs fs.ReadDirFS) (*Migrations, error) {
if connection == nil {
return nil, fmt.Errorf("db connection must be provided")
if tablename == "" {
return nil, fmt.Errorf("tablename must be provided")
if fs == nil {
return nil, fmt.Errorf("filesystem must be provided")
return &Migrations{
Connection: connection,
VersionTable: tablename,
MigrationsFS: fs,
}, nil
func (m *Migrations) PrepareDatabase(ctx context.Context) error {
if m.Connection == nil || m.VersionTable == "" {
return fmt.Errorf("uninitialized migration struct")
tablecheck := `SELECT count(*) AS count
FROM information_schema.TABLES
WHERE TABLE_NAME = '` + m.VersionTable + `'
tableschema := `CREATE TABLE ` + m.VersionTable + `(
datetime DATE,
var versionTableExists int
m.Connection.QueryRowContext(ctx, tablecheck).Scan(&versionTableExists)
if versionTableExists != 0 {
return nil
_, err := m.Connection.ExecContext(ctx, tableschema)
return err
func (m *Migrations) GetLatestMigration(ctx context.Context) (int, error) {
if m.Connection == nil || m.VersionTable == "" {
return 0, fmt.Errorf("uninitialized migration struct")
var latestMigration int
err := m.Connection.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM "+m.VersionTable).Scan(&latestMigration)
return latestMigration, err
func (m *Migrations) RunMigrations(ctx context.Context) (int, int, error) {
if m.Connection == nil || m.VersionTable == "" {
return 0, 0, fmt.Errorf("uninitialized migration struct")
migrations := map[int]migration{}
dir, err := m.MigrationsFS.ReadDir("/")
if err != nil {
return 0, 0, fmt.Errorf("could not load migrations from fs directory: %w", err)
for f := range dir {
if dir[f].Type().IsRegular() {
mig := migration{}
id, name, err := parseMigrationFileName(dir[f].Name())
if err != nil {
return 0, 0, fmt.Errorf("could not parse migration from fs directory: %w", err)
}, = id, name
mig.content, err = fs.ReadFile(m.MigrationsFS, "/"+dir[f].Name())
if err != nil {
return 0, 0, fmt.Errorf("could not load migration from fs: %w", err)
migrations[] = mig
latestMigrationRan, err := m.GetLatestMigration(ctx)
if err != nil {
return 0, 0, fmt.Errorf("unable to determine most recent migration: %w", err)
// exit if nothing to do (that is, there's no greater migration ID)
if _, ok := migrations[latestMigrationRan+1]; !ok {
return latestMigrationRan, 0, nil
// loop over and apply migrations if required
tx, err := m.Connection.BeginTx(ctx, nil)
if err != nil {
return latestMigrationRan, 0, fmt.Errorf("unable to open transaction: %w", err)
migrationsRun := 0
for migrationsToRun := true; migrationsToRun; _, migrationsToRun = migrations[latestMigrationRan+1] {
mig := migrations[latestMigrationRan+1]
_, err := tx.ExecContext(ctx, string(mig.content))
if err != nil {
nestederr := tx.Rollback()
if nestederr != nil {
return latestMigrationRan, migrationsRun, fmt.Errorf("error executing migration %d: %w", latestMigrationRan+1, nestederr)
return latestMigrationRan, migrationsRun, fmt.Errorf("error executing migration %d: %w", latestMigrationRan+1, err)
_, err = tx.ExecContext(ctx, "INSERT INTO "+m.VersionTable+" (id, name, datetime) VALUES (?, ?, ?)",,, time.Now())
if err != nil {
nestederr := tx.Rollback()
if nestederr != nil {
return latestMigrationRan, migrationsRun, fmt.Errorf("failure recording migration in '%s': %w", m.VersionTable, nestederr)
return latestMigrationRan, migrationsRun, fmt.Errorf("failure recording migration in '%s': %w", m.VersionTable, err)
latestMigrationRan = latestMigrationRan + 1
migrationsRun = migrationsRun + 1
err = tx.Commit()
if err != nil {
return latestMigrationRan, migrationsRun, fmt.Errorf("failure committing transaction: %w", err)
return latestMigrationRan, migrationsRun, nil
// parseMigrationFileName splits a file name of the form "<id>-<name>.sql"
// into its numeric id and its name with any ".sql" suffix removed.
// It returns an error when the name contains no "-" separator, or when the
// text before the first "-" is not a base-10 integer. The strconv error is
// wrapped, so it remains reachable through errors.As.
func parseMigrationFileName(filename string) (int, string, error) {
	parts := strings.SplitN(filename, "-", 2)
	// Without a separator there is no name part; indexing parts[1] would panic.
	if len(parts) != 2 {
		return 0, "", fmt.Errorf("migration file name %q must look like <id>-<name>.sql", filename)
	}
	id, err := strconv.Atoi(parts[0])
	if err != nil {
		return 0, "", fmt.Errorf("migration file name %q has invalid id: %w", filename, err)
	}
	return id, strings.TrimSuffix(parts[1], ".sql"), nil
}