Skip to content

Commit b9adc6d

Browse files
committed
Better loader delimiter check
1 parent d727315 commit b9adc6d

File tree

3 files changed

+24
-11
lines changed

3 files changed

+24
-11
lines changed

load.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dbump
22

33
import (
4+
"fmt"
45
"io/fs"
56
"os"
67
"path/filepath"
@@ -10,7 +11,6 @@ import (
1011
)
1112

1213
type FS interface {
13-
fs.FS
1414
fs.ReadDirFS
1515
fs.ReadFileFS
1616
}
@@ -116,33 +116,35 @@ func loadMigrationFromFS(fsys FS, path, id, name string) (*Migration, error) {
116116
return nil, err
117117
}
118118

119-
m := parseMigration(body)
119+
m, err := parseMigration(body)
120+
if err != nil {
121+
return nil, err
122+
}
120123
m.ID = int(n)
121124
m.Name = name
122125
return m, nil
123126
}
124127

125-
func parseMigration(body []byte) *Migration {
126-
// TODO(oleg): get name from magic comment
127-
parts := strings.SplitN(string(body), MigrationDelimiter, 2)
128-
applySQL := strings.TrimSpace(parts[0])
128+
func parseMigration(body []byte) (*Migration, error) {
129+
parts := strings.Split(string(body), MigrationDelimiter)
129130

130-
var revertSQL string
131-
if len(parts) == 2 {
132-
revertSQL = strings.TrimSpace(parts[1])
131+
if size := len(parts); size != 2 {
132+
return nil, fmt.Errorf("should have 2 parts separated by MigrationDelimiter but got: %d", size)
133133
}
134+
applySQL := strings.TrimSpace(parts[0])
135+
revertSQL := strings.TrimSpace(parts[1])
134136

135137
return &Migration{
136138
Apply: applySQL,
137139
Revert: revertSQL,
138-
}
140+
}, nil
139141
}
140142

141143
type osFS struct{}
142144

143145
// Open implements dbump.FS interface.
144146
func (osFS) Open(name string) (fs.File, error) {
145-
return os.Open(name)
147+
panic("unreachable")
146148
}
147149

148150
// ReadDir implements dbump.FS interface.

load_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,9 @@ func TestSliceLoader(t *testing.T) {
7575
mustEqual(t, migs[i], want[i])
7676
}
7777
}
78+
79+
func TestBadFormat(t *testing.T) {
80+
loader := NewFileSysLoader(testdata, "testdata/bad")
81+
_, err := loader.Load()
82+
failIfOk(t, err)
83+
}

testdata/bad/0001_init.sql

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
CREATE TABLE ok_table (
2+
id INT
3+
)
4+
--- apply above AND revert below ---
5+
DROP TABLE ok_table;

0 commit comments

Comments
 (0)