// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package squibble provides a schema migration assistant for SQLite databases.
//
// # Overview
//
// A Schema value manages the schema of a SQLite database that will be modified
// over time. The current database schema is stored in the Current field, and
// migrations from previous versions are captured as UpdateRules.
//
// When the program starts up, it should pass the open database to the Apply
// method of the Schema. This verifies that the Schema is valid, then checks
// whether the database is up-to-date. If not, it applies any relevant update
// rules to bring it to the current state. If Apply fails, the database is
// rolled back.
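//
// A minimal sketch of that startup flow (the driver name, file path, and
// variable names here are illustrative):
//
//	db, err := sql.Open("sqlite", "app.db")
//	if err != nil {
//		log.Fatal(err)
//	}
//	schema := &squibble.Schema{
//		Current: schemaText, // SQL text of the current schema
//		Updates: []squibble.UpdateRule{ /* one rule per prior version */ },
//	}
//	if err := schema.Apply(ctx, db); err != nil {
//		log.Fatal(err) // any partial changes were rolled back
//	}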
//
// The Schema tracks schema versions by hashing the schema with SHA256, and it
// stores a record of upgrades in a _schema_history table that it maintains.
// Apply creates this table if it does not already exist, and updates it as
// update rules are applied.
//
// # Update Rules
//
// The Updates field of the Schema must contain an ordered list of update rules
// for all the versions of the schema prior to the Current one, from oldest to
// newest. Each rule has the hash of a previous schema version and a function
// that can be applied to the database to upgrade it to the next version in
// sequence.
//
// When revising the schema, you must add a new rule mapping the old (existing)
// schema to the new one. These rules are intended to be a permanent record of
// changes, and should be committed into source control as part of the
// program. As a consistency check, each rule must also declare the hash of the
// target schema it upgrades to.
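//
// For instance, a schema that has gone through two revisions might carry
// rules like these (the digests and function names are invented for
// illustration):
//
//	Updates: []squibble.UpdateRule{
//		{Source: "2c4d...", Target: "6a1b...", Apply: addEmailColumn},
//		{Source: "6a1b...", Target: "9f3e...", Apply: addUserIndex}, // Target matches Current
//	}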
//
// When Apply runs, it looks for the most recent version of the schema recorded
// in the _schema_history table. If there is none, and the database is
// otherwise empty, the current schema is assumed to be the initial version,
// and it is applied directly. Otherwise, Apply compares the hash of the most
// recent update to the current version: If they differ, it finds the most
// recent update hash in the Updates list, and applies all the updates from
// that point forward. If this succeeds, the current schema is recorded as the
// latest version in _schema_history.
//
// # Validation
//
// You use the Validate function to check that the current schema in the
// special sqlite_schema table maintained by SQLite matches a schema written as
// SQL text. If not, it reports a diff describing the differences between what
// the text wants and what the real schema has.
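//
// A sketch of such a check (Validate is defined elsewhere in this package;
// the call shown here assumes it takes a context, a database connection, and
// the schema text):
//
//	if err := squibble.Validate(ctx, db, schemaText); err != nil {
//		log.Fatalf("Schema mismatch: %v", err)
//	}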
//
// # Limitations
//
// Currently this package only handles the main database, not attachments.
package squibble

import (
	"bytes"
	"context"
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"time"

	"github.com/klauspost/compress/zstd"

	_ "embed"
)

const (
	// historyTableName is the name of the history log table maintained by the
	// Schema migrator in a database under its management. See history.sql.
	historyTableName = "_schema_history"

	queryHistoryRows   = `SELECT timestamp, digest, schema FROM ` + historyTableName + ` ORDER BY timestamp`
	queryHistoryInsert = `INSERT INTO ` + historyTableName + ` (timestamp, digest, schema) VALUES (?, ?, ?)`
)

//go:embed history.sql
var historyTableSchema string

// Schema defines a family of SQLite schema versions over time, expressed as a
// SQL definition of the current version of the schema, plus an ordered
// collection of upgrade rules that define how to update each version to the
// next.
type Schema struct {
	// Current is the SQL definition of the most current version of the schema.
	// It must not be empty.
	Current string

	// Updates is a sequence of schema update rules. The slice must contain an
	// entry for each schema version prior to the newest.
	Updates []UpdateRule

	// Logf is where logs should be sent; the default is log.Printf.
	Logf func(string, ...any)
}

// An UpdateRule defines a schema upgrade.
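//
// A rule typically wraps one or more DDL statements. This sketch assumes
// DBConn exposes ExecContext; the digests and table name are invented for
// illustration:
//
//	squibble.UpdateRule{
//		Source: "2c4d...", // digest of the old schema
//		Target: "6a1b...", // digest of the schema after the change
//		Apply: func(ctx context.Context, db squibble.DBConn) error {
//			squibble.Logf(ctx, "adding email column to users")
//			_, err := db.ExecContext(ctx, `ALTER TABLE users ADD COLUMN email TEXT`)
//			return err
//		},
//	}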
type UpdateRule struct {
	// Source is the hex-encoded SHA256 digest of the schema at which this
	// update applies. It must not be empty.
	Source string

	// Target is the hex-encoded SHA256 digest of the schema reached by applying
	// this update. It must not be empty.
	Target string

	// Apply applies the necessary changes to update the schema to the next
	// version in sequence. It must not be nil.
	//
	// An apply function can use squibble.Logf(ctx, ...) to write log messages
	// to the logger defined by the associated Schema.
	Apply func(ctx context.Context, db DBConn) error
}

// logf writes a log message to the logger defined by s, falling back to
// log.Printf if s is nil or no logger is set.
func (s *Schema) logf(msg string, args ...any) {
	if s == nil || s.Logf == nil {
		log.Printf(msg, args...)
	} else {
		s.Logf(msg, args...)
	}
}

// ctxSchemaKey is the context key used by Logf to find the *Schema attached
// to a context.
type ctxSchemaKey struct{}

// Logf sends a log message to the logger attached to ctx, or to log.Printf if
// ctx does not have a logger attached. The context passed to the apply
// function of an UpdateRule will have this set to the logger for the Schema.
func Logf(ctx context.Context, msg string, args ...any) {
	s, _ := ctx.Value(ctxSchemaKey{}).(*Schema)
	s.logf(msg, args...)
}

// Apply applies any pending schema migrations to the given database. It
// reports an error immediately if s is not consistent (per Check); otherwise
// it creates a new transaction and attempts to apply all applicable upgrades
// to db within it. If this succeeds and the transaction commits successfully,
// then Apply succeeds. Otherwise, the transaction is rolled back and Apply
// reports the reason why.
//
// When applying a schema to an existing unmanaged database, Apply reports an
// error if the current schema is not compatible with the existing schema;
// otherwise it applies the current schema and updates the history.
func (s *Schema) Apply(ctx context.Context, db *sql.DB) error {
	if err := s.Check(); err != nil {
		return err
	}
	s.logf("Checking schema version...")

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()

	// Stage 1: Create the schema versions table, if it does not exist.
	// TODO(creachadair): Plumb an option for the table name.
	if _, err := tx.ExecContext(ctx, historyTableSchema); err != nil {
		return fmt.Errorf("create schema history: %w", err)
	}

	// Stage 2: Check whether the schema is up-to-date.
	curHash, err := SQLDigest(s.Current)
	if err != nil {
		return err
	}
	latestHash, err := DBDigest(ctx, tx)
	if err != nil {
		return err
	}
	hr, err := History(ctx, tx)
	if err != nil {
		return fmt.Errorf("reading update history: %w", err)
	} else if len(hr) == 0 {
		// Case 1: There is no schema present in the history table.
		if latestHash != curHash {
			if !schemaIsEmpty(ctx, tx, "main") {
				return errors.New("unmanaged schema already present")
			}
			if _, err := tx.ExecContext(ctx, s.Current); err != nil {
				return fmt.Errorf("apply schema: %w", err)
			}
			s.logf("Initialized database with schema %s", curHash)
		} else {
			s.logf("Schema %s is already current; updating history", curHash)
		}
		if err := s.addVersion(ctx, tx, HistoryRow{
			Timestamp: time.Now(),
			Digest:    curHash,
			Schema:    s.Current,
		}); err != nil {
			return err
		}
		return tx.Commit()
	}

	// Case 2: The current schema is up-to-date.
	if latestHash == curHash {
		s.logf("Schema is up-to-date at digest %s", curHash)
		return nil
	}

	// Case 3: The current schema is not the latest. Apply pending changes.
	last := hr[len(hr)-1]
	s.logf("Last updated to %s at %s", last.Digest, last.Timestamp.Format(time.RFC3339Nano))
	s.logf("Database schema: %s", latestHash)
	s.logf("Target schema: %s", curHash)

	// N.B. It is possible that a given schema will repeat in the history. In
	// that case, however, it doesn't matter which one we start from: All the
	// upgrades following ANY copy of that schema apply to all of them. We
	// choose the last, just because it's less work if that happens.
	i := s.firstPendingUpdate(latestHash)
	if i < 0 {
		return fmt.Errorf("no update found for digest %s (did you add an update rule?)", latestHash)
	}

	// Apply all the updates from the latest hash to the present.
	s.logf("Applying %d pending schema upgrades", len(s.Updates)-i)
	uctx := context.WithValue(ctx, ctxSchemaKey{}, s)
	for j, update := range s.Updates[i:] {
		if err := update.Apply(uctx, tx); err != nil {
			return fmt.Errorf("update failed at digest %s: %w", update.Source, err)
		}
		conf, err := DBDigest(uctx, tx)
		if err != nil {
			return fmt.Errorf("confirming update: %w", err)
		}
		if conf != update.Target {
			return fmt.Errorf("confirming update: got %s, want %s", conf, update.Target)
		}
		s.logf("[%d] updated to digest %s", i+j+1, update.Target)
	}

	// Now record that we made it to the front of the history.
	if err := s.addVersion(ctx, tx, HistoryRow{
		Timestamp: time.Now(),
		Digest:    curHash,
		Schema:    s.Current,
	}); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("upgrades failed: %w", err)
	}
	s.logf("Schema successfully updated to digest %s", curHash)
	return nil
}

// addVersion records the given schema version as the newest row of the
// history table within tx.
func (s *Schema) addVersion(ctx context.Context, tx *sql.Tx, version HistoryRow) error {
	_, err := tx.ExecContext(ctx, queryHistoryInsert,
		version.Timestamp.UnixMicro(), version.Digest, compress(version.Schema))
	if err != nil {
		return fmt.Errorf("record schema %s: %w", version.Digest, err)
	}
	return nil
}

// firstPendingUpdate returns the index in s.Updates of the rule whose Source
// matches digest, or -1 if there is none. If the digest occurs multiple
// times, the last matching rule is chosen (see the comment in Apply).
func (s *Schema) firstPendingUpdate(digest string) int {
	for i := len(s.Updates) - 1; i >= 0; i-- {
		if s.Updates[i].Source == digest {
			return i
		}
	}
	return -1
}

// Check reports an error if there are consistency problems with the schema
// definition that prevent it from being applied.
//
// A Schema is consistent if it has a non-empty Current schema text, all the
// update rules are correctly stitched (prev.Target == next.Source), and the
// last update rule in the sequence has the current schema as its target.
func (s *Schema) Check() error {
	if s.Current == "" {
		return errors.New("no current schema is defined")
	}
	hc, err := SQLDigest(s.Current)
	if err != nil {
		return err
	}
	var errs []error
	var last string
	for i, u := range s.Updates {
		if u.Source == "" {
			errs = append(errs, fmt.Errorf("upgrade %d: missing source", i+1))
		}
		if u.Target == "" {
			errs = append(errs, fmt.Errorf("upgrade %d: missing target", i+1))
		}
		if u.Apply == nil {
			errs = append(errs, fmt.Errorf("upgrade %d: missing Apply function", i+1))
		}
		if last != "" && u.Source != last {
			errs = append(errs, fmt.Errorf("upgrade %d: want source %s, got %s", i+1, last, u.Source))
		}
		last = u.Target
	}
	if last != "" && last != hc {
		errs = append(errs, fmt.Errorf("missing upgrade from %s to target %s", last, hc))
	}
	return errors.Join(errs...)
}

// History reports the history of schema upgrades recorded by db in
// chronological order.
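//
// For example, to print the recorded upgrades (a sketch):
//
//	hr, err := squibble.History(ctx, db)
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, r := range hr {
//		fmt.Println(r.Timestamp.Format(time.RFC3339), r.Digest)
//	}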
func History(ctx context.Context, db DBConn) ([]HistoryRow, error) {
	rows, err := db.QueryContext(ctx, queryHistoryRows)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []HistoryRow
	for rows.Next() {
		var ts int64
		var digest string
		var schemaBytes []byte
		if err := rows.Scan(&ts, &digest, &schemaBytes); err != nil {
			return nil, fmt.Errorf("scan history: %w", err)
		}
		out = append(out, HistoryRow{
			Timestamp: time.UnixMicro(ts).UTC(),
			Digest:    digest,
			Schema:    uncompress(schemaBytes),
		})
	}
	// Surface any error that ended the iteration early.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("read history: %w", err)
	}
	return out, nil
}

// HistoryRow is a row in the schema history maintained by the Schema type.
type HistoryRow struct {
	Timestamp time.Time `json:"timestamp"`     // In UTC
	Digest    string    `json:"digest"`        // The digest of the schema at this update
	Schema    string    `json:"sql,omitempty"` // The SQL of the schema at this update
}

// schemaDigest computes the hex-encoded SHA256 digest of the given schema
// rows, after normalizing them.
func schemaDigest(sr []schemaRow) string {
	// N.B. We don't include the SQL in the hash for tables, since it can be
	// mangled by ALTER TABLE executions. We rely on the Columns instead.
	//
	// For other types with SQL definitions (e.g., views) we use the SQL with
	// the whitespace normalized, since that is not affected by ALTER TABLE.
	for i, r := range sr {
		if r.Type == "table" {
			sr[i].SQL = ""
		} else {
			sr[i].SQL = cleanSQL(sr[i].SQL)
		}
	}
	h := sha256.New()
	json.NewEncoder(h).Encode(sr)
	return hex.EncodeToString(h.Sum(nil))
}

// SQLDigest computes a hex-encoded SHA256 digest of the SQLite schema encoded
// by the specified string.
func SQLDigest(text string) (string, error) {
	sr, err := schemaTextToRows(context.Background(), text)
	if err != nil {
		return "", err
	}
	return schemaDigest(sr), nil
}

// DBDigest computes a hex-encoded SHA256 digest of the SQLite schema encoded in
// the specified database.
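//
// Comparing DBDigest against SQLDigest of a schema text is a quick way to
// check whether a live database matches the schema you expect (a sketch):
//
//	want, _ := squibble.SQLDigest(schemaText)
//	got, _ := squibble.DBDigest(ctx, db)
//	if got != want {
//		// The database schema differs from schemaText.
//	}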
func DBDigest(ctx context.Context, db DBConn) (string, error) {
	sr, err := readSchema(ctx, db, "main")
	if err != nil {
		return "", err
	}
	return schemaDigest(sr), nil
}

// compress encodes text with zstd for storage in the history table.
func compress(text string) []byte {
	e, err := zstd.NewWriter(io.Discard)
	if err != nil {
		panic(fmt.Sprintf("NewWriter: %v", err))
	}
	return e.EncodeAll([]byte(text), nil)
}

// uncompress decodes a zstd-compressed blob from the history table back into
// schema text. A blob that does not decode as zstd is returned verbatim.
func uncompress(blob []byte) string {
	if len(blob) == 0 {
		return ""
	}
	d, err := zstd.NewReader(bytes.NewReader(nil))
	if err != nil {
		panic(fmt.Sprintf("NewReader: %v", err))
	}
	dec, err := d.DecodeAll(blob, nil)
	if err != nil {
		return string(blob)
	}
	return string(dec)
}