Skip to content

Commit 883eb18

Browse files
committed
Add migration functions
1 parent eeca5c5 commit 883eb18

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

dbump.go

+26-8
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,20 @@ type Loader interface {
3131

3232
// Migration represents migration step that will be runned on DB.
3333
type Migration struct {
34-
ID int // ID of the migration, unique, positive, starts from 1.
35-
Name string // Name of the migration
36-
Apply string // Apply query
37-
Rollback string // Rollback query
34+
ID int // ID of the migration, unique, positive, starts from 1.
35+
Name string // Name of the migration
36+
Apply string // Apply query
37+
Rollback string // Rollback query
38+
ApplyFn MigrationFn // Apply func
39+
RollbackFn MigrationFn // Rollback func
40+
41+
isQuery bool // shortcut for the type of migration (query or func)
42+
}
43+
44+
type MigrationFn func(db DB) error
45+
46+
type DB interface {
47+
Exec(ctx context.Context, query string, args ...interface{}) error
3848
}
3949

4050
// Run the Migrator with migration queries provided by the Loader.
@@ -62,7 +72,10 @@ func loadMigrations(ms []*Migration, err error) ([]*Migration, error) {
6272
case m.ID > want:
6373
return nil, fmt.Errorf("missing migration number: %d (have %d)", want, m.ID)
6474
default:
65-
// pass
75+
if (m.Apply != "" || m.Rollback != "") && (m.ApplyFn != nil || m.RollbackFn != nil) {
76+
return nil, fmt.Errorf("mixing queries and functions is not allowed (migration %d)", m.ID)
77+
}
78+
m.isQuery = m.Apply != ""
6679
}
6780
}
6881
return ms, nil
@@ -115,15 +128,20 @@ func runMigrationExclusive(ctx context.Context, m Migrator, ms []*Migration) err
115128
for currentVersion != targetVersion {
116129
current := ms[currentVersion]
117130
sequence := current.ID
118-
query := current.Apply
131+
query, queryFn := current.Apply, current.ApplyFn
119132

120133
if direction == -1 {
121134
current = ms[currentVersion-1]
122135
sequence = current.ID - 1
123-
query = current.Rollback
136+
query, queryFn = current.Rollback, current.RollbackFn
124137
}
125138

126-
if err := m.Exec(ctx, query); err != nil {
139+
if current.isQuery {
140+
err = m.Exec(ctx, query)
141+
} else {
142+
err = queryFn(m)
143+
}
144+
if err != nil {
127145
return fmt.Errorf("exec: %w", err)
128146
}
129147

0 commit comments

Comments
 (0)