library/database/mysql.go

212 lines
5.2 KiB
Go

package database
import (
"context"
"database/sql"
"embed"
"fmt"
"io/fs"
"strconv"
"strings"
"time"
"git.yetaga.in/alazyreader/library/book"
_ "github.com/go-sql-driver/mysql"
)
// migrationsFS holds the SQL migration files under migrations/mysql, embedded
// into the binary at build time. Paths inside an embed.FS are slash-separated
// and unrooted (e.g. "migrations/mysql/<file>"), so lookups must not begin
// with "/".
//
//go:embed migrations/mysql
var migrationsFS embed.FS
// migration is one schema change loaded from an embedded migration file.
type migration struct {
	// id is the numeric prefix of the file name; RunMigrations applies
	// migrations in consecutive id order.
	id int
	// name is the rest of the file name after the first "-", without ".sql".
	name string
	// content is the raw SQL text of the file.
	content []byte
}
// MySQL is a client that stores books in a MySQL database and manages the
// schema through embedded SQL migrations.
type MySQL struct {
	// connection is the database/sql handle opened by NewMySQLConnection.
	connection *sql.DB
	// versionTable names the table recording which migrations have been applied.
	versionTable string
	// migrationsDirectory is the directory inside migrationsFS that holds the
	// migration files.
	migrationsDirectory string
}
// NewMySQLConnection returns a MySQL client for the given credentials and
// database. sql.Open may only validate its arguments without contacting the
// server, so connection problems can first surface on later calls such as
// PrepareDatabase.
//
// The client records applied migrations in the "migrations" table and reads
// migration files from the embedded "migrations/mysql" directory.
func NewMySQLConnection(user, pass, host, port, db string) (*MySQL, error) {
	// go-sql-driver/mysql DSN format: user:password@tcp(host:port)/dbname
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", user, pass, host, port, db)
	connection, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, fmt.Errorf("opening mysql connection: %w", err)
	}
	return &MySQL{
		connection:   connection,
		versionTable: "migrations",
		// embed.FS paths are unrooted and slash-separated; a leading "/" is
		// never a valid fs path, so every lookup would fail with ErrNotExist.
		migrationsDirectory: "migrations/mysql",
	}, nil
}
// PrepareDatabase ensures the migration version table exists in the current
// database, creating it when absent. It should run before GetLatestMigration
// or RunMigrations, both of which read that table.
func (m *MySQL) PrepareDatabase(ctx context.Context) error {
	if m.connection == nil || m.migrationsDirectory == "" || m.versionTable == "" {
		return fmt.Errorf("uninitialized mysql client")
	}
	// The table name is compared as a value here, so it is bound as a
	// placeholder rather than spliced into the SQL text.
	tablecheck := `SELECT count(*) AS count
		FROM information_schema.TABLES
		WHERE TABLE_NAME = ?
		AND TABLE_SCHEMA = DATABASE()`
	var versionTableExists int
	if err := m.connection.QueryRowContext(ctx, tablecheck, m.versionTable).Scan(&versionTableExists); err != nil {
		// Without this check a failed query looked like "table missing" and
		// led to a confusing CREATE TABLE error.
		return fmt.Errorf("checking for version table %q: %w", m.versionTable, err)
	}
	if versionTableExists != 0 {
		return nil
	}
	// Identifiers cannot be bound as placeholders, so the table name is
	// concatenated; in this file versionTable is only set by
	// NewMySQLConnection, never from external input.
	// Note: a DATE column stores only the calendar day of each migration run.
	tableschema := `CREATE TABLE ` + m.versionTable + `(
		id INT NOT NULL,
		name VARCHAR(100) NOT NULL,
		datetime DATE,
		PRIMARY KEY (id))`
	if _, err := m.connection.ExecContext(ctx, tableschema); err != nil {
		return fmt.Errorf("creating version table %q: %w", m.versionTable, err)
	}
	return nil
}
// GetLatestMigration returns the highest migration ID recorded in the version
// table, or 0 when no migrations have been applied yet.
func (m *MySQL) GetLatestMigration(ctx context.Context) (int, error) {
	if m.connection == nil || m.migrationsDirectory == "" || m.versionTable == "" {
		return 0, fmt.Errorf("uninitialized mysql client")
	}
	// MAX over an empty table yields NULL, which cannot be scanned into a
	// plain int; NullInt64 reports it as Valid=false with Int64 == 0.
	var latestMigration sql.NullInt64
	err := m.connection.QueryRowContext(ctx, "SELECT MAX(id) FROM "+m.versionTable).Scan(&latestMigration)
	if err != nil {
		return 0, fmt.Errorf("reading latest migration from %q: %w", m.versionTable, err)
	}
	return int(latestMigration.Int64), nil
}
// RunMigrations applies, inside one transaction, every embedded migration
// whose ID follows the latest one recorded in the version table, stopping at
// the first missing ID (so a gap in the numbering halts migration). It
// returns the ID of the last migration applied, or the previously recorded
// latest ID when there is nothing to do.
//
// NOTE(review): MySQL implicitly commits after most DDL statements, so a
// failure part-way through may leave earlier migrations applied despite the
// rollback. Also, unless the DSN enables multiStatements, each migration file
// presumably must hold a single SQL statement — confirm against the driver
// configuration.
func (m *MySQL) RunMigrations(ctx context.Context) (int, error) {
	if m.connection == nil || m.migrationsDirectory == "" || m.versionTable == "" {
		return 0, fmt.Errorf("uninitialized mysql client")
	}
	migrations, err := loadMigrations(migrationsFS, m.migrationsDirectory)
	if err != nil {
		return 0, err
	}
	latestMigrationRan, err := m.GetLatestMigration(ctx)
	if err != nil {
		return 0, err
	}
	// exit if nothing to do (that is, there's no next migration ID)
	if _, ok := migrations[latestMigrationRan+1]; !ok {
		return latestMigrationRan, nil
	}
	tx, err := m.connection.BeginTx(ctx, nil)
	if err != nil {
		return latestMigrationRan, err
	}
	// rollback aborts the transaction and returns cause, also mentioning any
	// rollback failure instead of discarding the original error.
	rollback := func(cause error) error {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("%w (rollback also failed: %v)", cause, rbErr)
		}
		return cause
	}
	for {
		mig, ok := migrations[latestMigrationRan+1]
		if !ok {
			break
		}
		if _, err := tx.ExecContext(ctx, string(mig.content)); err != nil {
			return latestMigrationRan, rollback(fmt.Errorf("applying migration %d (%s): %w", mig.id, mig.name, err))
		}
		if _, err := tx.ExecContext(ctx, "INSERT INTO "+m.versionTable+" (id, name, datetime) VALUES (?, ?, ?)", mig.id, mig.name, time.Now()); err != nil {
			return latestMigrationRan, rollback(fmt.Errorf("recording migration %d (%s): %w", mig.id, mig.name, err))
		}
		latestMigrationRan = mig.id
	}
	if err := tx.Commit(); err != nil {
		return latestMigrationRan, fmt.Errorf("committing migrations: %w", err)
	}
	return latestMigrationRan, nil
}

// loadMigrations reads every regular file in dir of fsys and parses it into a
// migration keyed by its numeric ID. Two files sharing an ID are an error,
// since only one of them could ever be applied.
func loadMigrations(fsys fs.FS, dir string) (map[int]migration, error) {
	entries, err := fs.ReadDir(fsys, dir)
	if err != nil {
		return nil, fmt.Errorf("reading migrations directory %q: %w", dir, err)
	}
	migrations := make(map[int]migration, len(entries))
	for _, entry := range entries {
		if !entry.Type().IsRegular() {
			continue
		}
		id, name, err := parseMigrationFileName(entry.Name())
		if err != nil {
			return nil, err
		}
		if prev, dup := migrations[id]; dup {
			return nil, fmt.Errorf("migrations %q and %q share id %d", prev.name, name, id)
		}
		content, err := fs.ReadFile(fsys, dir+"/"+entry.Name())
		if err != nil {
			return nil, fmt.Errorf("reading migration %q: %w", entry.Name(), err)
		}
		migrations[id] = migration{id: id, name: name, content: content}
	}
	return migrations, nil
}
// GetAllBooks returns every row of the books table. Only ID and Title are
// read; any other book.Book fields are left at their zero values. An empty
// table yields an empty, non-nil slice.
func (m *MySQL) GetAllBooks(ctx context.Context) ([]book.Book, error) {
	if m.connection == nil {
		return nil, fmt.Errorf("uninitialized mysql client")
	}
	rows, err := m.connection.QueryContext(ctx, "SELECT id, title FROM books")
	if err != nil {
		return nil, fmt.Errorf("querying books: %w", err)
	}
	defer rows.Close()
	books := []book.Book{}
	for rows.Next() {
		b := book.Book{}
		if err := rows.Scan(&b.ID, &b.Title); err != nil {
			return nil, fmt.Errorf("scanning book row: %w", err)
		}
		books = append(books, b)
	}
	// rows.Next returns false both when rows are exhausted and when iteration
	// fails; without this check a failure returned a truncated list as success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating books: %w", err)
	}
	return books, nil
}
// AddBook inserts a row into the books table carrying b's title. b.ID is not
// sent to the database and is not updated after the insert. It returns an
// error unless exactly one row was affected.
func (m *MySQL) AddBook(ctx context.Context, b *book.Book) error {
	if m.connection == nil {
		return fmt.Errorf("uninitialized mysql client")
	}
	if b == nil {
		return fmt.Errorf("cannot add nil book")
	}
	res, err := m.connection.ExecContext(ctx, "INSERT INTO books (title) VALUES (?)", b.Title)
	if err != nil {
		return err
	}
	i, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if i != 1 {
		// i may be 0 as well as >1, so say what was expected rather than
		// claiming "more than one row".
		return fmt.Errorf("expected to insert exactly one row, affected %d", i)
	}
	return nil
}
// UpdateBook sets the title of the book whose ID is old.ID to new.Title. It
// returns an error unless exactly one row is reported as affected.
//
// NOTE(review): by default MySQL (and go-sql-driver/mysql) report rows
// *changed* rather than rows matched, so re-saving an unchanged title
// affects 0 rows and is reported as an error here unless the DSN sets
// clientFoundRows=true.
func (m *MySQL) UpdateBook(ctx context.Context, old, new *book.Book) error {
	if m.connection == nil {
		return fmt.Errorf("uninitialized mysql client")
	}
	if old == nil || new == nil {
		return fmt.Errorf("cannot update with nil book")
	}
	res, err := m.connection.ExecContext(ctx, "UPDATE books SET title=? WHERE id=?", new.Title, old.ID)
	if err != nil {
		return err
	}
	i, err := res.RowsAffected()
	if err != nil {
		return err
	}
	switch {
	case i == 0:
		return fmt.Errorf("no book updated for id %v (missing id or unchanged title)", old.ID)
	case i != 1:
		return fmt.Errorf("unexpectedly updated more than one row: %d", i)
	}
	return nil
}
// parseMigrationFileName splits a migration file name of the form
// "<id>-<name>.sql" into its integer ID and name; for example
// "1-create-books.sql" yields (1, "create-books"). Only the first "-"
// separates the two parts, and a trailing ".sql" is stripped if present.
// It returns an error, rather than panicking, when the name has no "-" or
// the ID part is not an integer.
func parseMigrationFileName(filename string) (int, string, error) {
	sp := strings.SplitN(filename, "-", 2)
	if len(sp) != 2 {
		return 0, "", fmt.Errorf("migration file name %q is not of the form <id>-<name>.sql", filename)
	}
	i, err := strconv.Atoi(sp[0])
	if err != nil {
		return 0, "", fmt.Errorf("migration file name %q has invalid id: %w", filename, err)
	}
	return i, strings.TrimSuffix(sp[1], ".sql"), nil
}