diff --git a/db/migrations.go b/db/migrations.go new file mode 100644 index 0000000..7b9a665 --- /dev/null +++ b/db/migrations.go @@ -0,0 +1,165 @@ +// Migrations provides an agnostic way to run database migrations for a service. +// Just provide a database connection and FS full of migration scripts, and +// call PrepareDatabase() and RunMigrations() (both idempotent) and you're off +// to the races. +// +// Currently assumes MySQL, but will be expanded with more DB backends Soon™️. +package grabdb + +import ( + "context" + "database/sql" + "fmt" + "io/fs" + "strconv" + "strings" + "time" +) + +type Migrations struct { + Connection *sql.DB + VersionTable string + MigrationsFS fs.ReadDirFS +} + +type migration struct { + id int + name string + content []byte +} + +// NewMigrations returns a new Migration struct that provides db migration functions. +// fs must be a filesystem that has a folder of .sql files in the root, +// named in the following pattern: 00-name.sql, 01-name.sql, 02-name.sql... +// Migrations will be applied in numerical sort order. Migrations will only be applied once. +// Behavior of duplicated integers is undefined, but probably unpleasant. +func NewMigrations(connection *sql.DB, dbType string, tablename string, fs fs.ReadDirFS) (*Migrations, error) { + if connection == nil { + return nil, fmt.Errorf("db connection must be provided") + } + if tablename == "" { + return nil, fmt.Errorf("tablename must be provided") + } + if fs == nil { + return nil, fmt.Errorf("filesystem must be provided") + } + return &Migrations{ + Connection: connection, + VersionTable: tablename, + MigrationsFS: fs, + }, nil +} + +func (m *Migrations) PrepareDatabase(ctx context.Context) error { + if m.Connection == nil || m.VersionTable == "" { + return fmt.Errorf("uninitialized migration struct") + } + + tablecheck := `SELECT count(*) AS count + FROM information_schema.TABLES + WHERE TABLE_NAME = '` + m.VersionTable + `' + AND TABLE_SCHEMA in (SELECT DATABASE());` + tableschema := `CREATE TABLE ` + m.VersionTable + `( + id INT NOT NULL, + name VARCHAR(100) NOT NULL, + datetime DATE, + PRIMARY KEY (id))` + + var versionTableExists int + m.Connection.QueryRowContext(ctx, tablecheck).Scan(&versionTableExists) + if versionTableExists != 0 { + return nil + } + _, err := m.Connection.ExecContext(ctx, tableschema) + return err +} + +func (m *Migrations) GetLatestMigration(ctx context.Context) (int, error) { + if m.Connection == nil || m.VersionTable == "" { + return 0, fmt.Errorf("uninitialized migration struct") + } + + var latestMigration int + err := m.Connection.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM "+m.VersionTable).Scan(&latestMigration) + return latestMigration, err +} + +func (m *Migrations) RunMigrations(ctx context.Context) (int, int, error) { + if m.Connection == nil || m.VersionTable == "" { + return 0, 0, fmt.Errorf("uninitialized migration struct") + } + + migrations := map[int]migration{} + dir, err := m.MigrationsFS.ReadDir("/") + if err != nil { + return 0, 0, fmt.Errorf("could not load migrations from fs directory: %w", err) + } + for f := range dir { + if dir[f].Type().IsRegular() { + mig := migration{} + id, name, err := parseMigrationFileName(dir[f].Name()) + if err != nil { + return 0, 0, fmt.Errorf("could not parse migration from fs directory: %w", err) + } + mig.id, mig.name = id, name + mig.content, err = fs.ReadFile(m.MigrationsFS, "/"+dir[f].Name()) + if err != nil { + return 0, 0, fmt.Errorf("could not load migration from fs: %w", err) + } + migrations[mig.id] = mig + } + } + + latestMigrationRan, err := m.GetLatestMigration(ctx) + if err != nil { + return 0, 0, fmt.Errorf("unable to determine most recent migration: %w", err) + } + + // exit if nothing to do (that is, there's no greater migration ID) + if _, ok := migrations[latestMigrationRan+1]; !ok { + return latestMigrationRan, 0, nil + } + + // loop over and apply migrations if required + tx, err := m.Connection.BeginTx(ctx, nil) + if err != nil { + return latestMigrationRan, 0, fmt.Errorf("unable to open transaction: %w", err) + } + migrationsRun := 0 + for migrationsToRun := true; migrationsToRun; _, migrationsToRun = migrations[latestMigrationRan+1] { + mig := migrations[latestMigrationRan+1] + _, err := tx.ExecContext(ctx, string(mig.content)) + if err != nil { + nestederr := tx.Rollback() + if nestederr != nil { + return latestMigrationRan, migrationsRun, fmt.Errorf("error executing migration %d: %w", latestMigrationRan+1, nestederr) + } + return latestMigrationRan, migrationsRun, fmt.Errorf("error executing migration %d: %w", latestMigrationRan+1, err) + } + _, err = tx.ExecContext(ctx, "INSERT INTO "+m.VersionTable+" (id, name, datetime) VALUES (?, ?, ?)", mig.id, mig.name, time.Now()) + if err != nil { + nestederr := tx.Rollback() + if nestederr != nil { + return latestMigrationRan, migrationsRun, fmt.Errorf("failure recording migration in '%s': %w", m.VersionTable, nestederr) + } + return latestMigrationRan, migrationsRun, fmt.Errorf("failure recording migration in '%s': %w", m.VersionTable, err) + } + latestMigrationRan = latestMigrationRan + 1 + migrationsRun = migrationsRun + 1 + } + err = tx.Commit() + if err != nil { + return latestMigrationRan, migrationsRun, fmt.Errorf("failure committing transaction: %w", err) + } + return latestMigrationRan, migrationsRun, nil +} + +func parseMigrationFileName(filename string) (int, string, error) { + sp := strings.SplitN(filename, "-", 2) + i, err := strconv.Atoi(sp[0]) + if err != nil { + return 0, "", err + } + tr := strings.TrimSuffix(sp[1], ".sql") + return i, tr, nil +} diff --git a/http/http.go b/http/http.go new file mode 100644 index 0000000..6b0cc43 --- /dev/null +++ b/http/http.go @@ -0,0 +1,42 @@ +package grabhttp + +import ( + "encoding/json" + "io/fs" + "net/http" +) + +type errorStruct struct { + Status string `json:"status"` + Error string `json:"error"` +} + +func StaticHandler(f fs.FS) http.Handler { + return http.FileServer(http.FS(f)) +} + +func WriteJSON(i interface{}, status int, w http.ResponseWriter, r *http.Request) { + b, err := json.Marshal(i) + if err != nil { + ErrorJSON(err, http.StatusInternalServerError, w, r) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + w.Write(b) + w.Write([]byte("\n")) +} + +func WriteHTML(b []byte, status int, w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(status) + w.Write(b) + w.Write([]byte("\n")) +} + +func ErrorJSON(err error, status int, w http.ResponseWriter, r *http.Request) { + WriteJSON(errorStruct{ + Status: "error", + Error: err.Error(), + }, status, w, r) +} diff --git a/http/http_test.go b/http/http_test.go new file mode 100644 index 0000000..fca6307 --- /dev/null +++ b/http/http_test.go @@ -0,0 +1,70 @@ +package grabhttp + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestWriteJSON(t *testing.T) { + testcases := []struct { + input interface{} + code int + outcode int + output []byte + }{ + { + input: struct{}{}, + code: http.StatusOK, + output: []byte("{}\n"), + }, + { + input: struct{}{}, + code: http.StatusNotFound, + output: []byte("{}\n"), + }, + { + input: struct { + Foo string + }{ + Foo: "foo", + }, + code: http.StatusOK, + output: []byte(`{"Foo":"foo"}` + "\n"), + }, + { + input: struct { + Foo struct { + Bar string `json:"quuz"` + } + }{ + Foo: struct { + Bar string `json:"quuz"` + }{Bar: "foo"}, + }, + code: http.StatusOK, + output: []byte(`{"Foo":{"quuz":"foo"}}` + "\n"), + }, + { + input: struct{ C func() }{C: func() {}}, + code: http.StatusOK, + outcode: http.StatusInternalServerError, + output: []byte(`{"status":"error","error":"json: unsupported type: func()"}` + "\n"), + }, + } + for i, c := range testcases { + w := httptest.NewRecorder() + WriteJSON(c.input, c.code, w, httptest.NewRequest("", "/", nil)) + if w.Code != c.code && w.Code != c.outcode { + t.Logf("code mismatch on case %d", i) + t.Fail() + } + b, err := io.ReadAll(w.Body) + if err != nil || !bytes.Equal(b, c.output) { + t.Logf("failed: %s != %s", b, c.output) + t.Fail() + } + } +}