// Portions of this file are derived from https://github.com/go-gormigrate/gormigrate // MIT License // Copyright (c) 2016 Andrey Nering // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE package migrator import ( "context" "errors" "fmt" "reflect" "time" "gorm.io/gorm" ) const ( initSchemaMigrationID = "SCHEMA_INIT" ) // MigrateFunc is the func signature for migrating. type MigrateFunc func(*gorm.DB) error // RollbackFunc is the func signature for rollbacking. type RollbackFunc func(*gorm.DB) error // InitSchemaFunc is the func signature for initializing the schema. type InitSchemaFunc func(*gorm.DB) error // Options define options for all migrations. type Options struct { // TableName is the migration table. TableName string // IDColumnName is the name of column where the migration id will be stored. IDColumnName string // IDColumnSize is the length of the migration id column IDColumnSize int // SequenceColumnName is the name of the auto-incrementing numeric column. SequenceColumnName string // AppliedAtColumnName is the name of the column storing when the migration was applied. AppliedAtColumnName string // StatusColumnName is the name of the column storing the migration status (success/failure). StatusColumnName string // UseTransaction makes Gormigrate execute migrations inside a single transaction. // Keep in mind that not all databases support DDL commands inside transactions. UseTransaction bool // ValidateUnknownMigrations will cause migrate to fail if there's unknown migration // IDs in the database ValidateUnknownMigrations bool } // Migration represents a database migration (a modification to be made on the database). type Migration struct { // ID is the migration identifier. Usually a timestamp like "201601021504". ID string // Migrate is a function that will br executed while running this migration. Migrate MigrateFunc // Rollback will be executed on rollback. Can be nil. Rollback RollbackFunc } // Gormigrate represents a collection of all migrations of a database schema. type Gormigrate struct { db *gorm.DB tx *gorm.DB options *Options migrations []*Migration initSchema InitSchemaFunc } // ReservedIDError is returned when a migration is using a reserved ID type ReservedIDError struct { ID string } func (e *ReservedIDError) Error() string { return fmt.Sprintf(`gormigrate: Reserved migration ID: "%s"`, e.ID) } // DuplicatedIDError is returned when more than one migration have the same ID type DuplicatedIDError struct { ID string } func (e *DuplicatedIDError) Error() string { return fmt.Sprintf(`gormigrate: Duplicated migration ID: "%s"`, e.ID) } var ( // DefaultOptions can be used if you don't want to think about options. DefaultOptions = &Options{ TableName: "migrations", IDColumnName: "id", IDColumnSize: 255, SequenceColumnName: "sequence", AppliedAtColumnName: "applied_at", StatusColumnName: "status", UseTransaction: true, ValidateUnknownMigrations: false, } // ErrRollbackImpossible is returned when trying to rollback a migration // that has no rollback function. ErrRollbackImpossible = errors.New("gormigrate: It's impossible to rollback this migration") // ErrNoMigrationDefined is returned when no migration is defined. ErrNoMigrationDefined = errors.New("gormigrate: No migration defined") // ErrMissingID is returned when the ID od migration is equal to "" ErrMissingID = errors.New("gormigrate: Missing ID in migration") // ErrNoRunMigration is returned when any run migration was found while // running RollbackLast ErrNoRunMigration = errors.New("gormigrate: Could not find last run migration") // ErrMigrationIDDoesNotExist is returned when migrating or rolling back to a migration ID that // does not exist in the list of migrations ErrMigrationIDDoesNotExist = errors.New("gormigrate: Tried to migrate to an ID that doesn't exist") // ErrUnknownPastMigration is returned if a migration exists in the DB that doesn't exist in the code ErrUnknownPastMigration = errors.New("gormigrate: Found migration in DB that does not exist in code") ) // New returns a new Gormigrate. func New(db *gorm.DB, options *Options, migrations []*Migration) *Gormigrate { if options == nil { options = DefaultOptions } if options.TableName == "" { options.TableName = DefaultOptions.TableName } if options.IDColumnName == "" { options.IDColumnName = DefaultOptions.IDColumnName } if options.IDColumnSize == 0 { options.IDColumnSize = DefaultOptions.IDColumnSize } if options.SequenceColumnName == "" { options.SequenceColumnName = DefaultOptions.SequenceColumnName } if options.AppliedAtColumnName == "" { options.AppliedAtColumnName = DefaultOptions.AppliedAtColumnName } if options.StatusColumnName == "" { options.StatusColumnName = DefaultOptions.StatusColumnName } return &Gormigrate{ db: db, options: options, migrations: migrations, } } // InitSchema sets a function that is run if no migration is found. // The idea is preventing to run all migrations when a new clean database // is being migrating. In this function you should create all tables and // foreign key necessary to your application. func (g *Gormigrate) InitSchema(initSchema InitSchemaFunc) { g.initSchema = initSchema } // Migrate executes all migrations that did not run yet. func (g *Gormigrate) Migrate() error { if !g.hasMigrations() { return ErrNoMigrationDefined } var targetMigrationID string if len(g.migrations) > 0 { targetMigrationID = g.migrations[len(g.migrations)-1].ID } return g.migrate(targetMigrationID) } // MigrateTo executes all migrations that did not run yet up to the migration that matches `migrationID`. func (g *Gormigrate) MigrateTo(migrationID string) error { if err := g.checkIDExist(migrationID); err != nil { return err } return g.migrate(migrationID) } func (g *Gormigrate) migrate(migrationID string) error { if !g.hasMigrations() { return ErrNoMigrationDefined } if err := g.checkReservedID(); err != nil { return err } if err := g.checkDuplicatedID(); err != nil { return err } g.begin() defer g.rollback() if err := g.createMigrationTableIfNotExists(); err != nil { return err } if g.options.ValidateUnknownMigrations { unknownMigrations, err := g.unknownMigrationsHaveHappened() if err != nil { return err } if unknownMigrations { return ErrUnknownPastMigration } } if g.initSchema != nil { canInitializeSchema, err := g.canInitializeSchema() if err != nil { return err } if canInitializeSchema { if err := g.runInitSchema(); err != nil { return err } return g.commit() } } for _, migration := range g.migrations { if err := g.runMigration(migration); err != nil { return err } if migrationID != "" && migration.ID == migrationID { break } } return g.commit() } // There are migrations to apply if either there's a defined // initSchema function or if the list of migrations is not empty. func (g *Gormigrate) hasMigrations() bool { return g.initSchema != nil || len(g.migrations) > 0 } // Check whether any migration is using a reserved ID. // For now there's only have one reserved ID, but there may be more in the future. func (g *Gormigrate) checkReservedID() error { for _, m := range g.migrations { if m.ID == initSchemaMigrationID { return &ReservedIDError{ID: m.ID} } } return nil } func (g *Gormigrate) checkDuplicatedID() error { lookup := make(map[string]struct{}, len(g.migrations)) for _, m := range g.migrations { if _, ok := lookup[m.ID]; ok { return &DuplicatedIDError{ID: m.ID} } lookup[m.ID] = struct{}{} } return nil } func (g *Gormigrate) checkIDExist(migrationID string) error { for _, migrate := range g.migrations { if migrate.ID == migrationID { return nil } } return ErrMigrationIDDoesNotExist } // RollbackLast undo the last migration func (g *Gormigrate) RollbackLast() error { if len(g.migrations) == 0 { return ErrNoMigrationDefined } g.begin() defer g.rollback() lastRunMigration, err := g.getLastRunMigration() if err != nil { return err } if err := g.rollbackMigration(lastRunMigration); err != nil { return err } return g.commit() } // RollbackTo undoes migrations up to the given migration that matches the `migrationID`. // Migration with the matching `migrationID` is not rolled back. func (g *Gormigrate) RollbackTo(migrationID string) error { if len(g.migrations) == 0 { return ErrNoMigrationDefined } if err := g.checkIDExist(migrationID); err != nil { return err } g.begin() defer g.rollback() for i := len(g.migrations) - 1; i >= 0; i-- { migration := g.migrations[i] if migration.ID == migrationID { break } migrationRan, err := g.migrationRan(migration) if err != nil { return err } if migrationRan { if err := g.rollbackMigration(migration); err != nil { return err } } } return g.commit() } func (g *Gormigrate) getLastRunMigration() (*Migration, error) { for i := len(g.migrations) - 1; i >= 0; i-- { migration := g.migrations[i] migrationRan, err := g.migrationRan(migration) if err != nil { return nil, err } if migrationRan { return migration, nil } } return nil, ErrNoRunMigration } // RollbackMigration undo a migration. func (g *Gormigrate) RollbackMigration(m *Migration) error { g.begin() defer g.rollback() if err := g.rollbackMigration(m); err != nil { return err } return g.commit() } func (g *Gormigrate) rollbackMigration(m *Migration) error { if m.Rollback == nil { return ErrRollbackImpossible } if err := m.Rollback(g.tx); err != nil { return err } cond := fmt.Sprintf("%s = ?", g.options.IDColumnName) return g.tx.Table(g.options.TableName).Where(cond, m.ID).Delete(g.model()).Error } func (g *Gormigrate) runInitSchema() error { if err := g.initSchema(g.tx); err != nil { return err } if err := g.insertMigration(initSchemaMigrationID); err != nil { return err } for _, migration := range g.migrations { if err := g.insertMigration(migration.ID); err != nil { return err } } return nil } func (g *Gormigrate) runMigration(migration *Migration) error { if len(migration.ID) == 0 { return ErrMissingID } migrationRan, err := g.migrationRan(migration) if err != nil { return err } if !migrationRan { if err := migration.Migrate(g.tx); err != nil { return err } if err := g.insertMigration(migration.ID); err != nil { return err } } return nil } // model returns pointer to dynamically created gorm migration model struct value func (g *Gormigrate) model() any { fields := []reflect.StructField{ { Name: "ID", Type: reflect.TypeOf(""), Tag: reflect.StructTag(fmt.Sprintf( `gorm:"primaryKey;column:%s;size:%d"`, g.options.IDColumnName, g.options.IDColumnSize, )), }, { Name: "Sequence", Type: reflect.TypeOf(int64(0)), Tag: reflect.StructTag(fmt.Sprintf(`gorm:"column:%s"`, g.options.SequenceColumnName)), }, { Name: "AppliedAt", Type: reflect.TypeOf(time.Time{}), Tag: reflect.StructTag(fmt.Sprintf(`gorm:"column:%s"`, g.options.AppliedAtColumnName)), }, { Name: "Status", Type: reflect.TypeOf(""), Tag: reflect.StructTag(fmt.Sprintf(`gorm:"column:%s;size:20"`, g.options.StatusColumnName)), }, } structType := reflect.StructOf(fields) structValue := reflect.New(structType).Elem() return structValue.Addr().Interface() } func (g *Gormigrate) createMigrationTableIfNotExists() error { if err := g.tx.Table(g.options.TableName).AutoMigrate(g.model()); err != nil { return err } return g.backfillMigrationMetadata() } // backfillMigrationMetadata populates sequence, applied_at, and status for // rows that predate the addition of these columns (all marked as success // with the same timestamp). Rows are sequenced by their natural insertion // order (rowid for SQLite, ctid for PostgreSQL) so that the sequence column // reflects the actual order migrations were originally applied. func (g *Gormigrate) backfillMigrationMetadata() error { var orderCol string switch g.tx.Dialector.Name() { case "sqlite": orderCol = "rowid" case "postgres": orderCol = "ctid" default: orderCol = g.options.IDColumnName } var ids []string err := g.tx.Table(g.options.TableName). Where(fmt.Sprintf("%s IS NULL OR %s = ''", g.options.StatusColumnName, g.options.StatusColumnName)). Order(orderCol). Pluck(g.options.IDColumnName, &ids).Error if err != nil { return err } if len(ids) == 0 { return nil } now := time.Now() var maxSeq int64 if err := g.tx.Table(g.options.TableName). Select(fmt.Sprintf("COALESCE(MAX(%s), 0)", g.options.SequenceColumnName)). Scan(&maxSeq).Error; err != nil { return err } for i, id := range ids { err := g.tx.Table(g.options.TableName). Where(fmt.Sprintf("%s = ?", g.options.IDColumnName), id). Updates(map[string]interface{}{ g.options.SequenceColumnName: maxSeq + int64(i) + 1, g.options.AppliedAtColumnName: now, g.options.StatusColumnName: "success", }).Error if err != nil { return err } } return nil } func (g *Gormigrate) migrationRan(m *Migration) (bool, error) { var count int64 err := g.tx. Table(g.options.TableName). Where(fmt.Sprintf("%s = ?", g.options.IDColumnName), m.ID). Count(&count). Error return count > 0, err } // The schema can be initialised only if it hasn't been initialised yet // and no other migration has been applied already. func (g *Gormigrate) canInitializeSchema() (bool, error) { migrationRan, err := g.migrationRan(&Migration{ID: initSchemaMigrationID}) if err != nil { return false, err } if migrationRan { return false, nil } // If the ID doesn't exist, we also want the list of migrations to be empty var count int64 err = g.tx. Table(g.options.TableName). Count(&count). Error return count == 0, err } func (g *Gormigrate) unknownMigrationsHaveHappened() (bool, error) { rows, err := g.tx.Table(g.options.TableName).Select(g.options.IDColumnName).Rows() if err != nil { return false, err } defer func() { if err := rows.Close(); err != nil { g.tx.Logger.Error(context.TODO(), err.Error()) } }() validIDSet := make(map[string]struct{}, len(g.migrations)+1) validIDSet[initSchemaMigrationID] = struct{}{} for _, migration := range g.migrations { validIDSet[migration.ID] = struct{}{} } for rows.Next() { var pastMigrationID string if err := rows.Scan(&pastMigrationID); err != nil { return false, err } if _, ok := validIDSet[pastMigrationID]; !ok { return true, nil } } return false, nil } func (g *Gormigrate) nextSequence() (int64, error) { var maxSeq int64 err := g.tx.Table(g.options.TableName). Select(fmt.Sprintf("COALESCE(MAX(%s), 0)", g.options.SequenceColumnName)). Scan(&maxSeq).Error if err != nil { return 0, err } return maxSeq + 1, nil } func (g *Gormigrate) insertMigration(id string) error { seq, err := g.nextSequence() if err != nil { return err } record := g.model() v := reflect.ValueOf(record).Elem() v.FieldByName("ID").SetString(id) v.FieldByName("Sequence").SetInt(seq) v.FieldByName("AppliedAt").Set(reflect.ValueOf(time.Now())) v.FieldByName("Status").SetString("success") return g.tx.Table(g.options.TableName).Create(record).Error } func (g *Gormigrate) begin() { if g.options.UseTransaction { g.tx = g.db.Begin() } else { g.tx = g.db } } func (g *Gormigrate) commit() error { if g.options.UseTransaction { return g.tx.Commit().Error } return nil } func (g *Gormigrate) rollback() { if g.options.UseTransaction { g.tx.Rollback() } }