initial planning
This commit is contained in:
79
database/memory.go
Normal file
79
database/memory.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"git.yetaga.in/alazyreader/library/book"
|
||||
)
|
||||
|
||||
type Memory struct {
|
||||
lock sync.Mutex
|
||||
shelf []book.Book
|
||||
}
|
||||
|
||||
func (m *Memory) GetAllBooks(_ context.Context) ([]book.Book, error) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
if m.shelf == nil {
|
||||
m.shelf = []book.Book{}
|
||||
}
|
||||
|
||||
return m.shelf, nil
|
||||
}
|
||||
|
||||
func (m *Memory) AddBook(_ context.Context, b *book.Book) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
if m.shelf == nil {
|
||||
m.shelf = []book.Book{}
|
||||
}
|
||||
|
||||
m.shelf = append(m.shelf, *b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Memory) UpdateBook(_ context.Context, old, new *book.Book) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
if m.shelf == nil {
|
||||
m.shelf = []book.Book{}
|
||||
return fmt.Errorf("book does not exist")
|
||||
}
|
||||
|
||||
if old.ID != new.ID {
|
||||
return fmt.Errorf("cannot change book ID")
|
||||
}
|
||||
|
||||
for i := range m.shelf {
|
||||
if m.shelf[i].ID == old.ID {
|
||||
m.shelf[i] = *new
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("book does not exist")
|
||||
}
|
||||
|
||||
func (m *Memory) DeleteBook(_ context.Context, b *book.Book) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
if m.shelf == nil {
|
||||
m.shelf = []book.Book{}
|
||||
return fmt.Errorf("book does not exist")
|
||||
}
|
||||
|
||||
for i := range m.shelf {
|
||||
if m.shelf[i].ID == b.ID {
|
||||
// reorder slice to remove book quickly
|
||||
m.shelf[i] = m.shelf[len(m.shelf)-1]
|
||||
m.shelf = m.shelf[:len(m.shelf)-1]
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("book does not exist")
|
||||
}
|
0
database/migrations/mysql/01-init.sql
Normal file
0
database/migrations/mysql/01-init.sql
Normal file
211
database/mysql.go
Normal file
211
database/mysql.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.yetaga.in/alazyreader/library/book"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
//go:embed migrations/mysql
|
||||
var migrationsFS embed.FS
|
||||
|
||||
type migration struct {
|
||||
id int
|
||||
name string
|
||||
content []byte
|
||||
}
|
||||
|
||||
type MySQL struct {
|
||||
connection *sql.DB
|
||||
versionTable string
|
||||
migrationsDirectory string
|
||||
}
|
||||
|
||||
func NewMySQLConnection(user, pass, host, port, db string) (*MySQL, error) {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", user, pass, host, port, db) // what a strange syntax
|
||||
connection, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MySQL{
|
||||
connection: connection,
|
||||
versionTable: "migrations",
|
||||
migrationsDirectory: "/migrations/mysql",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MySQL) PrepareDatabase(ctx context.Context) error {
|
||||
if m.connection == nil || m.migrationsDirectory == "" || m.versionTable == "" {
|
||||
return fmt.Errorf("uninitialized mysql client")
|
||||
}
|
||||
|
||||
tablecheck := `SELECT count(*) AS count
|
||||
FROM information_schema.TABLES
|
||||
WHERE TABLE_NAME = '` + m.versionTable + `'
|
||||
AND TABLE_SCHEMA in (SELECT DATABASE());`
|
||||
tableschema := `CREATE TABLE ` + m.versionTable + `(
|
||||
id INT NOT NULL,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
datetime DATE,
|
||||
PRIMARY KEY (id))`
|
||||
|
||||
var versionTableExists int
|
||||
m.connection.QueryRowContext(ctx, tablecheck).Scan(&versionTableExists)
|
||||
if versionTableExists != 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := m.connection.ExecContext(ctx, tableschema)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *MySQL) GetLatestMigration(ctx context.Context) (int, error) {
|
||||
if m.connection == nil || m.migrationsDirectory == "" || m.versionTable == "" {
|
||||
return 0, fmt.Errorf("uninitialized mysql client")
|
||||
}
|
||||
|
||||
var latestMigration int
|
||||
err := m.connection.QueryRowContext(ctx, "SELECT MAX(id) FROM "+m.versionTable).Scan(&latestMigration)
|
||||
return latestMigration, err
|
||||
}
|
||||
|
||||
func (m *MySQL) RunMigrations(ctx context.Context) (int, error) {
|
||||
if m.connection == nil || m.migrationsDirectory == "" || m.versionTable == "" {
|
||||
return 0, fmt.Errorf("uninitialized mysql client")
|
||||
}
|
||||
|
||||
var migrations map[int]migration
|
||||
dir, err := migrationsFS.ReadDir(m.migrationsDirectory)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
for f := range dir {
|
||||
if dir[f].Type().IsRegular() {
|
||||
mig := migration{}
|
||||
id, name, err := parseMigrationFileName(dir[f].Name())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
mig.id, mig.name = id, name
|
||||
mig.content, err = fs.ReadFile(migrationsFS, m.migrationsDirectory+"/"+dir[f].Name())
|
||||
migrations[mig.id] = mig
|
||||
}
|
||||
}
|
||||
|
||||
latestMigrationRan, err := m.GetLatestMigration(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// exit if nothing to do (that is, there's no greater migration ID)
|
||||
if _, ok := migrations[latestMigrationRan+1]; !ok {
|
||||
return latestMigrationRan, nil
|
||||
}
|
||||
|
||||
// loop over and apply migrations if required
|
||||
tx, err := m.connection.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return latestMigrationRan, err
|
||||
}
|
||||
for migrationsToRun := true; migrationsToRun; _, migrationsToRun = migrations[latestMigrationRan+1] {
|
||||
mig := migrations[latestMigrationRan+1]
|
||||
_, err := tx.ExecContext(ctx, string(mig.content))
|
||||
if err != nil {
|
||||
nestederr := tx.Rollback()
|
||||
if nestederr != nil {
|
||||
return latestMigrationRan, nestederr
|
||||
}
|
||||
return latestMigrationRan, err
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "INSERT INTO "+m.versionTable+" (id, name, datetime) VALUES (?, ?, ?)", mig.id, mig.name, time.Now())
|
||||
if err != nil {
|
||||
nestederr := tx.Rollback()
|
||||
if nestederr != nil {
|
||||
return latestMigrationRan, nestederr
|
||||
}
|
||||
return latestMigrationRan, err
|
||||
}
|
||||
latestMigrationRan = latestMigrationRan + 1
|
||||
}
|
||||
err = tx.Commit()
|
||||
return latestMigrationRan, err
|
||||
}
|
||||
|
||||
func (m *MySQL) GetAllBooks(ctx context.Context) ([]book.Book, error) {
|
||||
if m.connection == nil {
|
||||
return nil, fmt.Errorf("uninitialized mysql client")
|
||||
}
|
||||
|
||||
books := []book.Book{}
|
||||
rows, err := m.connection.QueryContext(ctx, "SELECT id, title FROM books")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
b := book.Book{}
|
||||
err := rows.Scan(&b.ID, &b.Title)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
books = append(books, b)
|
||||
}
|
||||
|
||||
return books, nil
|
||||
}
|
||||
|
||||
func (m *MySQL) AddBook(ctx context.Context, b *book.Book) error {
|
||||
if m.connection == nil {
|
||||
return fmt.Errorf("uninitialized mysql client")
|
||||
}
|
||||
|
||||
res, err := m.connection.ExecContext(ctx, "INSERT INTO books (title) VALUES (?)", b.Title)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if i != 1 {
|
||||
return fmt.Errorf("unexpectedly updated more than one row: %d", i)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) UpdateBook(ctx context.Context, old, new *book.Book) error {
|
||||
if m.connection == nil {
|
||||
return fmt.Errorf("uninitialized mysql client")
|
||||
}
|
||||
|
||||
res, err := m.connection.ExecContext(ctx, "UPDATE books SET title=? WHERE id=?", new.Title, old.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if i != 1 {
|
||||
return fmt.Errorf("unexpectedly updated more than one row: %d", i)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseMigrationFileName(filename string) (int, string, error) {
|
||||
sp := strings.SplitN(filename, "-", 2)
|
||||
i, err := strconv.Atoi(sp[0])
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
tr := strings.TrimSuffix(sp[1], ".sql")
|
||||
return i, tr, nil
|
||||
}
|
Reference in New Issue
Block a user