begin noodling on a little module of common idioms I use in personal projects
This commit is contained in:
parent
c2972abeb0
commit
4d5136d992
165
db/migrations.go
Normal file
165
db/migrations.go
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
// Migrations provides an agnostic way to run database migrations for a service.
|
||||||
|
// Just provide a database connection and FS full of migration scripts, and
|
||||||
|
// call PrepareDatabase() and RunMigrations() (both idempotent) and you're off
|
||||||
|
// to the races.
|
||||||
|
//
|
||||||
|
// Currently assumes MySQL, but will be expanded with more DB backends Soon™️.
|
||||||
|
package grabdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Migrations struct {
|
||||||
|
Connection *sql.DB
|
||||||
|
VersionTable string
|
||||||
|
MigrationsFS fs.ReadDirFS
|
||||||
|
}
|
||||||
|
|
||||||
|
type migration struct {
|
||||||
|
id int
|
||||||
|
name string
|
||||||
|
content []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMigrations returns a new Migration struct that provides db migration functions.
|
||||||
|
// fs must be a filesystem that has a folder of .sql files in the root,
|
||||||
|
// named in the following pattern: 00-name.sql, 01-name.sql, 02-name.sql...
|
||||||
|
// Migrations will be applied in numerical sort order. Migrations will only be applied once.
|
||||||
|
// Behavior of duplicated integers is undefined, but probably unpleasant.
|
||||||
|
func NewMigrations(connection *sql.DB, dbType string, tablename string, fs fs.ReadDirFS) (*Migrations, error) {
|
||||||
|
if connection == nil {
|
||||||
|
return nil, fmt.Errorf("db connection must be provided")
|
||||||
|
}
|
||||||
|
if tablename == "" {
|
||||||
|
return nil, fmt.Errorf("tablename must be provided")
|
||||||
|
}
|
||||||
|
if fs == nil {
|
||||||
|
return nil, fmt.Errorf("filesystem must be provided")
|
||||||
|
}
|
||||||
|
return &Migrations{
|
||||||
|
Connection: connection,
|
||||||
|
VersionTable: tablename,
|
||||||
|
MigrationsFS: fs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Migrations) PrepareDatabase(ctx context.Context) error {
|
||||||
|
if m.Connection == nil || m.VersionTable == "" {
|
||||||
|
return fmt.Errorf("uninitialized migration struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
tablecheck := `SELECT count(*) AS count
|
||||||
|
FROM information_schema.TABLES
|
||||||
|
WHERE TABLE_NAME = '` + m.VersionTable + `'
|
||||||
|
AND TABLE_SCHEMA in (SELECT DATABASE());`
|
||||||
|
tableschema := `CREATE TABLE ` + m.VersionTable + `(
|
||||||
|
id INT NOT NULL,
|
||||||
|
name VARCHAR(100) NOT NULL,
|
||||||
|
datetime DATE,
|
||||||
|
PRIMARY KEY (id))`
|
||||||
|
|
||||||
|
var versionTableExists int
|
||||||
|
m.Connection.QueryRowContext(ctx, tablecheck).Scan(&versionTableExists)
|
||||||
|
if versionTableExists != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err := m.Connection.ExecContext(ctx, tableschema)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Migrations) GetLatestMigration(ctx context.Context) (int, error) {
|
||||||
|
if m.Connection == nil || m.VersionTable == "" {
|
||||||
|
return 0, fmt.Errorf("uninitialized migration struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
var latestMigration int
|
||||||
|
err := m.Connection.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM "+m.VersionTable).Scan(&latestMigration)
|
||||||
|
return latestMigration, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Migrations) RunMigrations(ctx context.Context) (int, int, error) {
|
||||||
|
if m.Connection == nil || m.VersionTable == "" {
|
||||||
|
return 0, 0, fmt.Errorf("uninitialized migration struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
migrations := map[int]migration{}
|
||||||
|
dir, err := m.MigrationsFS.ReadDir("/")
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("could not load migrations from fs directory: %w", err)
|
||||||
|
}
|
||||||
|
for f := range dir {
|
||||||
|
if dir[f].Type().IsRegular() {
|
||||||
|
mig := migration{}
|
||||||
|
id, name, err := parseMigrationFileName(dir[f].Name())
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("could not parse migration from fs directory: %w", err)
|
||||||
|
}
|
||||||
|
mig.id, mig.name = id, name
|
||||||
|
mig.content, err = fs.ReadFile(m.MigrationsFS, "/"+dir[f].Name())
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("could not load migration from fs: %w", err)
|
||||||
|
}
|
||||||
|
migrations[mig.id] = mig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
latestMigrationRan, err := m.GetLatestMigration(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("unable to determine most recent migration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exit if nothing to do (that is, there's no greater migration ID)
|
||||||
|
if _, ok := migrations[latestMigrationRan+1]; !ok {
|
||||||
|
return latestMigrationRan, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop over and apply migrations if required
|
||||||
|
tx, err := m.Connection.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return latestMigrationRan, 0, fmt.Errorf("unable to open transaction: %w", err)
|
||||||
|
}
|
||||||
|
migrationsRun := 0
|
||||||
|
for migrationsToRun := true; migrationsToRun; _, migrationsToRun = migrations[latestMigrationRan+1] {
|
||||||
|
mig := migrations[latestMigrationRan+1]
|
||||||
|
_, err := tx.ExecContext(ctx, string(mig.content))
|
||||||
|
if err != nil {
|
||||||
|
nestederr := tx.Rollback()
|
||||||
|
if nestederr != nil {
|
||||||
|
return latestMigrationRan, migrationsRun, fmt.Errorf("error executing migration %d: %w", latestMigrationRan+1, nestederr)
|
||||||
|
}
|
||||||
|
return latestMigrationRan, migrationsRun, fmt.Errorf("error executing migration %d: %w", latestMigrationRan+1, err)
|
||||||
|
}
|
||||||
|
_, err = tx.ExecContext(ctx, "INSERT INTO "+m.VersionTable+" (id, name, datetime) VALUES (?, ?, ?)", mig.id, mig.name, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
nestederr := tx.Rollback()
|
||||||
|
if nestederr != nil {
|
||||||
|
return latestMigrationRan, migrationsRun, fmt.Errorf("failure recording migration in '%s': %w", m.VersionTable, nestederr)
|
||||||
|
}
|
||||||
|
return latestMigrationRan, migrationsRun, fmt.Errorf("failure recording migration in '%s': %w", m.VersionTable, err)
|
||||||
|
}
|
||||||
|
latestMigrationRan = latestMigrationRan + 1
|
||||||
|
migrationsRun = migrationsRun + 1
|
||||||
|
}
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return latestMigrationRan, migrationsRun, fmt.Errorf("failure committing transaction: %w", err)
|
||||||
|
}
|
||||||
|
return latestMigrationRan, migrationsRun, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMigrationFileName(filename string) (int, string, error) {
|
||||||
|
sp := strings.SplitN(filename, "-", 2)
|
||||||
|
i, err := strconv.Atoi(sp[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", err
|
||||||
|
}
|
||||||
|
tr := strings.TrimSuffix(sp[1], ".sql")
|
||||||
|
return i, tr, nil
|
||||||
|
}
|
42
http/http.go
Normal file
42
http/http.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package grabhttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type errorStruct struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func StaticHandler(f fs.FS) http.Handler {
|
||||||
|
return http.FileServer(http.FS(f))
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteJSON(i interface{}, status int, w http.ResponseWriter, r *http.Request) {
|
||||||
|
b, err := json.Marshal(i)
|
||||||
|
if err != nil {
|
||||||
|
ErrorJSON(err, http.StatusInternalServerError, w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
w.Write(b)
|
||||||
|
w.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteHTML(b []byte, status int, w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
w.Write(b)
|
||||||
|
w.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorJSON(err error, status int, w http.ResponseWriter, r *http.Request) {
|
||||||
|
WriteJSON(errorStruct{
|
||||||
|
Status: "error",
|
||||||
|
Error: err.Error(),
|
||||||
|
}, status, w, r)
|
||||||
|
}
|
70
http/http_test.go
Normal file
70
http/http_test.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package grabhttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteJSON(t *testing.T) {
|
||||||
|
testcases := []struct {
|
||||||
|
input interface{}
|
||||||
|
code int
|
||||||
|
outcode int
|
||||||
|
output []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
input: struct{}{},
|
||||||
|
code: http.StatusOK,
|
||||||
|
output: []byte("{}\n"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: struct{}{},
|
||||||
|
code: http.StatusNotFound,
|
||||||
|
output: []byte("{}\n"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: struct {
|
||||||
|
Foo string
|
||||||
|
}{
|
||||||
|
Foo: "foo",
|
||||||
|
},
|
||||||
|
code: http.StatusOK,
|
||||||
|
output: []byte(`{"Foo":"foo"}` + "\n"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: struct {
|
||||||
|
Foo struct {
|
||||||
|
Bar string `json:"quuz"`
|
||||||
|
}
|
||||||
|
}{
|
||||||
|
Foo: struct {
|
||||||
|
Bar string `json:"quuz"`
|
||||||
|
}{Bar: "foo"},
|
||||||
|
},
|
||||||
|
code: http.StatusOK,
|
||||||
|
output: []byte(`{"Foo":{"quuz":"foo"}}` + "\n"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: struct{ C func() }{C: func() {}},
|
||||||
|
code: http.StatusOK,
|
||||||
|
outcode: http.StatusInternalServerError,
|
||||||
|
output: []byte(`{"status":"error","error":"json: unsupported type: func()"}` + "\n"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, c := range testcases {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
WriteJSON(c.input, c.code, w, httptest.NewRequest("", "/", nil))
|
||||||
|
if w.Code != c.code && w.Code != c.outcode {
|
||||||
|
t.Logf("code mismatch on case %d", i)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
b, err := io.ReadAll(w.Body)
|
||||||
|
if err != nil || !bytes.Equal(b, c.output) {
|
||||||
|
t.Logf("failed: %s != %s", b, c.output)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user