184 lines
3.9 KiB
Go
184 lines
3.9 KiB
Go
package migrate
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/golang-migrate/migrate/v4"
|
|
)
|
|
|
|
const defaultDir = "sql/migrations"
|
|
|
|
// Runner coordinates golang-migrate operations.
|
|
type Runner struct {
|
|
Dir string
|
|
}
|
|
|
|
// NewRunner creates a runner that reads migration files from dir. When dir is
|
|
// empty, the default directory under sql/migrations is used.
|
|
func NewRunner(dir string) *Runner {
|
|
if dir == "" {
|
|
dir = defaultDir
|
|
}
|
|
return &Runner{Dir: dir}
|
|
}
|
|
|
|
// Up executes all migrations that have not been applied yet. Each step logs
|
|
// its outcome to provide clear visibility.
|
|
func (r *Runner) Up(ctx context.Context, dsn string) error {
|
|
absDir, err := filepath.Abs(r.Dir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
migrations, err := r.loadMigrations(absDir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
m, err := migrate.New(fmt.Sprintf("file://%s", absDir), dsn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer closeMigrator(m)
|
|
|
|
currentVersion, dirty, err := m.Version()
|
|
if err != nil {
|
|
if errors.Is(err, migrate.ErrNilVersion) {
|
|
currentVersion = 0
|
|
} else {
|
|
return fmt.Errorf("fetch current version: %w", err)
|
|
}
|
|
}
|
|
|
|
if dirty {
|
|
return fmt.Errorf("database is in a dirty state at version %d; please fix manually", currentVersion)
|
|
}
|
|
|
|
applied := false
|
|
for _, migration := range migrations {
|
|
if migration.version <= currentVersion {
|
|
continue
|
|
}
|
|
|
|
fmt.Printf("→ Applying migration %s ...\n", migration.name)
|
|
if err := m.Migrate(migration.version); err != nil {
|
|
if errors.Is(err, migrate.ErrNoChange) {
|
|
fmt.Printf("✅ Migration %s already applied\n", migration.name)
|
|
continue
|
|
}
|
|
return fmt.Errorf("apply migration %s: %w", migration.name, err)
|
|
}
|
|
applied = true
|
|
fmt.Printf("✅ Migration %s applied\n", migration.name)
|
|
}
|
|
|
|
if !applied {
|
|
fmt.Println("✅ Database schema already up-to-date")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Version reports the current schema version tracked by golang-migrate.
|
|
func (r *Runner) Version(dsn string) (uint, bool, error) {
|
|
absDir, err := filepath.Abs(r.Dir)
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
|
|
m, err := migrate.New(fmt.Sprintf("file://%s", absDir), dsn)
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
defer closeMigrator(m)
|
|
|
|
version, dirty, err := m.Version()
|
|
if err != nil {
|
|
if errors.Is(err, migrate.ErrNilVersion) {
|
|
return 0, false, nil
|
|
}
|
|
return 0, false, err
|
|
}
|
|
|
|
return version, dirty, nil
|
|
}
|
|
|
|
// Reset drops the public schema before replaying all migrations.
|
|
func (r *Runner) Reset(ctx context.Context, dsn string) error {
|
|
db, err := openDB(ctx, dsn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer db.Close()
|
|
|
|
fmt.Println("⚠️ Dropping public schema (preserving pglogical)...")
|
|
if _, err := db.ExecContext(ctx, "DROP SCHEMA IF EXISTS public CASCADE"); err != nil {
|
|
return fmt.Errorf("drop public schema: %w", err)
|
|
}
|
|
if _, err := db.ExecContext(ctx, "CREATE SCHEMA IF NOT EXISTS public"); err != nil {
|
|
return fmt.Errorf("recreate public schema: %w", err)
|
|
}
|
|
|
|
return r.Up(ctx, dsn)
|
|
}
|
|
|
|
func (r *Runner) loadMigrations(absDir string) ([]*migrationFile, error) {
|
|
entries, err := os.ReadDir(absDir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
migrationMap := make(map[uint]*migrationFile)
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
name := entry.Name()
|
|
if !strings.HasSuffix(name, ".up.sql") {
|
|
continue
|
|
}
|
|
parts := strings.SplitN(name, "_", 2)
|
|
if len(parts) != 2 {
|
|
continue
|
|
}
|
|
version, err := strconv.ParseUint(parts[0], 10, 64)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
migrationMap[uint(version)] = &migrationFile{
|
|
version: uint(version),
|
|
name: name,
|
|
}
|
|
}
|
|
|
|
var migrations []*migrationFile
|
|
for _, m := range migrationMap {
|
|
migrations = append(migrations, m)
|
|
}
|
|
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].version < migrations[j].version
|
|
})
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
func closeMigrator(m *migrate.Migrate) {
|
|
if m == nil {
|
|
return
|
|
}
|
|
_, _ = m.Close()
|
|
}
|
|
|
|
type migrationFile struct {
|
|
version uint
|
|
name string
|
|
}
|