// Migrations provides an agnostic way to run database migrations for a service. // Just provide a database connection and FS full of migration scripts, and // call PrepareDatabase() and RunMigrations() (both idempotent) and you're off // to the races. // // Currently assumes MySQL, but will be expanded with more DB backends Soon™️. package grabdb import ( "context" "database/sql" "fmt" "io/fs" "strconv" "strings" "time" ) type Migrations struct { Connection *sql.DB VersionTable string MigrationsFS fs.ReadDirFS } type migration struct { id int name string content []byte } // NewMigrations returns a new Migration struct that provides db migration functions. // fs must be a filesystem that has a folder of .sql files in the root, // named in the following pattern: 00-name.sql, 01-name.sql, 02-name.sql... // Migrations will be applied in numerical sort order. Migrations will only be applied once. // Behavior of duplicated integers is undefined, but probably unpleasant. func NewMigrations(connection *sql.DB, dbType string, tablename string, fs fs.ReadDirFS) (*Migrations, error) { if connection == nil { return nil, fmt.Errorf("db connection must be provided") } if tablename == "" { return nil, fmt.Errorf("tablename must be provided") } if fs == nil { return nil, fmt.Errorf("filesystem must be provided") } return &Migrations{ Connection: connection, VersionTable: tablename, MigrationsFS: fs, }, nil } func (m *Migrations) PrepareDatabase(ctx context.Context) error { if m.Connection == nil || m.VersionTable == "" { return fmt.Errorf("uninitialized migration struct") } tablecheck := `SELECT count(*) AS count FROM information_schema.TABLES WHERE TABLE_NAME = '` + m.VersionTable + `' AND TABLE_SCHEMA in (SELECT DATABASE());` tableschema := `CREATE TABLE ` + m.VersionTable + `( id INT NOT NULL, name VARCHAR(100) NOT NULL, datetime DATE, PRIMARY KEY (id))` var versionTableExists int m.Connection.QueryRowContext(ctx, tablecheck).Scan(&versionTableExists) if versionTableExists != 0 { return nil } _, err := m.Connection.ExecContext(ctx, tableschema) return err } func (m *Migrations) GetLatestMigration(ctx context.Context) (int, error) { if m.Connection == nil || m.VersionTable == "" { return 0, fmt.Errorf("uninitialized migration struct") } var latestMigration int err := m.Connection.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM "+m.VersionTable).Scan(&latestMigration) return latestMigration, err } func (m *Migrations) RunMigrations(ctx context.Context) (int, int, error) { if m.Connection == nil || m.VersionTable == "" { return 0, 0, fmt.Errorf("uninitialized migration struct") } migrations := map[int]migration{} dir, err := m.MigrationsFS.ReadDir("/") if err != nil { return 0, 0, fmt.Errorf("could not load migrations from fs directory: %w", err) } for f := range dir { if dir[f].Type().IsRegular() { mig := migration{} id, name, err := parseMigrationFileName(dir[f].Name()) if err != nil { return 0, 0, fmt.Errorf("could not parse migration from fs directory: %w", err) } mig.id, mig.name = id, name mig.content, err = fs.ReadFile(m.MigrationsFS, "/"+dir[f].Name()) if err != nil { return 0, 0, fmt.Errorf("could not load migration from fs: %w", err) } migrations[mig.id] = mig } } latestMigrationRan, err := m.GetLatestMigration(ctx) if err != nil { return 0, 0, fmt.Errorf("unable to determine most recent migration: %w", err) } // exit if nothing to do (that is, there's no greater migration ID) if _, ok := migrations[latestMigrationRan+1]; !ok { return latestMigrationRan, 0, nil } // loop over and apply migrations if required tx, err := m.Connection.BeginTx(ctx, nil) if err != nil { return latestMigrationRan, 0, fmt.Errorf("unable to open transaction: %w", err) } migrationsRun := 0 for migrationsToRun := true; migrationsToRun; _, migrationsToRun = migrations[latestMigrationRan+1] { mig := migrations[latestMigrationRan+1] _, err := tx.ExecContext(ctx, string(mig.content)) if err != nil { nestederr := tx.Rollback() if nestederr != nil { return latestMigrationRan, migrationsRun, fmt.Errorf("error executing migration %d: %w", latestMigrationRan+1, nestederr) } return latestMigrationRan, migrationsRun, fmt.Errorf("error executing migration %d: %w", latestMigrationRan+1, err) } _, err = tx.ExecContext(ctx, "INSERT INTO "+m.VersionTable+" (id, name, datetime) VALUES (?, ?, ?)", mig.id, mig.name, time.Now()) if err != nil { nestederr := tx.Rollback() if nestederr != nil { return latestMigrationRan, migrationsRun, fmt.Errorf("failure recording migration in '%s': %w", m.VersionTable, nestederr) } return latestMigrationRan, migrationsRun, fmt.Errorf("failure recording migration in '%s': %w", m.VersionTable, err) } latestMigrationRan = latestMigrationRan + 1 migrationsRun = migrationsRun + 1 } err = tx.Commit() if err != nil { return latestMigrationRan, migrationsRun, fmt.Errorf("failure committing transaction: %w", err) } return latestMigrationRan, migrationsRun, nil } func parseMigrationFileName(filename string) (int, string, error) { sp := strings.SplitN(filename, "-", 2) i, err := strconv.Atoi(sp[0]) if err != nil { return 0, "", err } tr := strings.TrimSuffix(sp[1], ".sql") return i, tr, nil }