package database import ( "context" "database/sql" "embed" "fmt" "io/fs" "strconv" "strings" "time" "git.yetaga.in/alazyreader/library/media" _ "github.com/go-sql-driver/mysql" ) //go:embed migrations/mysql var migrationsFS embed.FS type migration struct { id int name string content []byte } type MySQL struct { connection *sql.DB tableName string versionTable string migrationsDirectory string } func NewMySQLConnection(user, pass, host, port, db string) (*MySQL, error) { dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", user, pass, host, port, db) // what a strange syntax connection, err := sql.Open("mysql", dsn) if err != nil { return nil, err } return &MySQL{ connection: connection, tableName: "books", versionTable: "migrations", migrationsDirectory: "migrations/mysql", }, nil } func (m *MySQL) PrepareDatabase(ctx context.Context) error { if m.connection == nil || m.migrationsDirectory == "" || m.versionTable == "" { return fmt.Errorf("uninitialized mysql client") } tablecheck := fmt.Sprintf(`SELECT count(*) AS count FROM information_schema.TABLES WHERE TABLE_NAME = '%s' AND TABLE_SCHEMA in (SELECT DATABASE());`, m.versionTable) tableschema := fmt.Sprintf(`CREATE TABLE %s ( id INT NOT NULL, name VARCHAR(100) NOT NULL, datetime DATE, PRIMARY KEY (id))`, m.versionTable) var versionTableExists int m.connection.QueryRowContext(ctx, tablecheck).Scan(&versionTableExists) if versionTableExists != 0 { return nil } _, err := m.connection.ExecContext(ctx, tableschema) return err } func (m *MySQL) GetLatestMigration(ctx context.Context) (int, error) { if m.connection == nil || m.migrationsDirectory == "" || m.versionTable == "" { return 0, fmt.Errorf("uninitialized mysql client") } migrationCheck := fmt.Sprintf("SELECT COALESCE(MAX(id), 0) FROM %s", m.versionTable) var latestMigration int err := m.connection.QueryRowContext(ctx, migrationCheck).Scan(&latestMigration) return latestMigration, err } func (m *MySQL) RunMigrations(ctx context.Context) (int, int, error) { if m.connection == nil || m.migrationsDirectory == "" || m.versionTable == "" { return 0, 0, fmt.Errorf("uninitialized mysql client") } migrations := map[int]migration{} dir, err := migrationsFS.ReadDir(m.migrationsDirectory) if err != nil { return 0, 0, err } for f := range dir { if dir[f].Type().IsRegular() { mig := migration{} id, name, err := parseMigrationFileName(dir[f].Name()) if err != nil { return 0, 0, err } mig.id, mig.name = id, name mig.content, err = fs.ReadFile(migrationsFS, m.migrationsDirectory+"/"+dir[f].Name()) if err != nil { return 0, 0, fmt.Errorf("failure loading migration: %w", err) } migrations[mig.id] = mig } } latestMigrationRan, err := m.GetLatestMigration(ctx) if err != nil { return 0, 0, err } // exit if nothing to do (that is, there's no greater migration ID) if _, ok := migrations[latestMigrationRan+1]; !ok { return latestMigrationRan, 0, nil } // loop over and apply migrations if required tx, err := m.connection.BeginTx(ctx, nil) if err != nil { return latestMigrationRan, 0, err } migrationLogSql := fmt.Sprintf("INSERT INTO %s (id, name, datetime) VALUES (?, ?, ?)", m.versionTable) migrationsRun := 0 for migrationsToRun := true; migrationsToRun; _, migrationsToRun = migrations[latestMigrationRan+1] { mig := migrations[latestMigrationRan+1] _, err := tx.ExecContext(ctx, string(mig.content)) if err != nil { nestederr := tx.Rollback() if nestederr != nil { return latestMigrationRan, migrationsRun, nestederr } return latestMigrationRan, migrationsRun, err } _, err = tx.ExecContext(ctx, migrationLogSql, mig.id, mig.name, time.Now()) if err != nil { nestederr := tx.Rollback() if nestederr != nil { return latestMigrationRan, migrationsRun, nestederr } return latestMigrationRan, migrationsRun, err } latestMigrationRan = latestMigrationRan + 1 migrationsRun = migrationsRun + 1 } err = tx.Commit() return latestMigrationRan, migrationsRun, err } func (m *MySQL) GetAllBooks(ctx context.Context) ([]media.Book, error) { if m.connection == nil { return nil, fmt.Errorf("uninitialized mysql client") } allBooksQuery := fmt.Sprintf(`SELECT id, title, authors, sortauthor, isbn10, isbn13, format, genre, publisher, series, volume, year, signed, description, notes, coverurl FROM %s`, m.tableName) books := []media.Book{} rows, err := m.connection.QueryContext(ctx, allBooksQuery) if err != nil { return nil, err } defer rows.Close() for rows.Next() { b := media.Book{} var authors string err := rows.Scan( &b.ID, &b.Title, &authors, &b.SortAuthor, &b.ISBN10, &b.ISBN13, &b.Format, &b.Genre, &b.Publisher, &b.Series, &b.Volume, &b.Year, &b.Signed, &b.Description, &b.Notes, &b.CoverURL) if err != nil { return nil, err } b.Authors = strings.Split(authors, ";") books = append(books, b) } return books, nil } func (m *MySQL) AddBook(ctx context.Context, b *media.Book) error { if m.connection == nil { return fmt.Errorf("uninitialized mysql client") } res, err := m.connection.ExecContext(ctx, ` INSERT INTO `+m.tableName+` (title, authors, sortauthor, isbn10, isbn13, format, genre, publisher, series, volume, year, signed, description, notes, coverurl) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, b.Title, strings.Join(b.Authors, ";"), b.SortAuthor, b.ISBN10, b.ISBN13, b.Format, b.Genre, b.Publisher, b.Series, b.Volume, b.Year, b.Signed, b.Description, b.Notes, b.CoverURL, ) if err != nil { return err } i, err := res.RowsAffected() if err != nil { return err } if i != 1 { return fmt.Errorf("unexpectedly updated more than one row: %d", i) } return nil } func (m *MySQL) UpdateBook(ctx context.Context, old, new *media.Book) error { if m.connection == nil { return fmt.Errorf("uninitialized mysql client") } if old.ID != new.ID { return fmt.Errorf("cannot change book ID") } res, err := m.connection.ExecContext(ctx, ` UPDATE `+m.tableName+` SET id=? title=? authors=? sortauthor=? isbn10=? isbn13=? format=? genre=? publisher=? series=? volume=? year=? signed=? description=? notes=? coverurl=? WHERE id=?`, new.Title, strings.Join(new.Authors, ";"), new.SortAuthor, new.ISBN10, new.ISBN13, new.Format, new.Genre, new.Publisher, new.Series, new.Volume, new.Year, new.Signed, new.Description, new.Notes, new.CoverURL, old.ID) if err != nil { return err } i, err := res.RowsAffected() if err != nil { return err } if i != 1 { return fmt.Errorf("unexpectedly updated more than one row: %d", i) } return nil } func parseMigrationFileName(filename string) (int, string, error) { sp := strings.SplitN(filename, "-", 2) i, err := strconv.Atoi(sp[0]) if err != nil { return 0, "", err } tr := strings.TrimSuffix(sp[1], ".sql") return i, tr, nil }