Skip to content

Commit 9486411

Browse files
authored
Rework context cancellation. (#251)
1 parent befed7c commit 9486411

File tree

7 files changed

+90
-64
lines changed

7 files changed

+90
-64
lines changed

blob.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ var _ io.ReadWriteSeeker = &Blob{}
3131
//
3232
// https://sqlite.org/c3ref/blob_open.html
3333
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
34+
if c.interrupt.Err() != nil {
35+
return nil, INTERRUPT
36+
}
37+
3438
defer c.arena.mark()()
3539
blobPtr := c.arena.new(ptrlen)
3640
dbPtr := c.arena.string(db)
@@ -42,7 +46,6 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
4246
flags = 1
4347
}
4448

45-
c.checkInterrupt()
4649
rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle),
4750
stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr),
4851
stk_t(row), stk_t(flags), stk_t(blobPtr)))
@@ -253,7 +256,9 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
253256
//
254257
// https://sqlite.org/c3ref/blob_reopen.html
255258
func (b *Blob) Reopen(row int64) error {
256-
b.c.checkInterrupt()
259+
if b.c.interrupt.Err() != nil {
260+
return INTERRUPT
261+
}
257262
err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row))))
258263
b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle))))
259264
b.offset = 0

config.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,14 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
275275
//
276276
// https://sqlite.org/c3ref/wal_checkpoint_v2.html
277277
func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) {
278+
if c.interrupt.Err() != nil {
279+
return 0, 0, INTERRUPT
280+
}
281+
278282
defer c.arena.mark()()
279283
nLogPtr := c.arena.new(ptrlen)
280284
nCkptPtr := c.arena.new(ptrlen)
281285
schemaPtr := c.arena.string(schema)
282-
283-
c.checkInterrupt()
284286
rc := res_t(c.call("sqlite3_wal_checkpoint_v2",
285287
stk_t(c.handle), stk_t(schemaPtr), stk_t(mode),
286288
stk_t(nLogPtr), stk_t(nCkptPtr)))

conn.go

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ type Conn struct {
4040
busylst time.Time
4141
arena arena
4242
handle ptr_t
43-
pending ptr_t
44-
stepped bool
4543
gosched uint8
4644
}
4745

@@ -167,9 +165,6 @@ func (c *Conn) Close() error {
167165
return nil
168166
}
169167

170-
c.call("sqlite3_finalize", stk_t(c.pending))
171-
c.pending = 0
172-
173168
rc := res_t(c.call("sqlite3_close", stk_t(c.handle)))
174169
if err := c.error(rc); err != nil {
175170
return err
@@ -184,10 +179,15 @@ func (c *Conn) Close() error {
184179
//
185180
// https://sqlite.org/c3ref/exec.html
186181
func (c *Conn) Exec(sql string) error {
182+
if c.interrupt.Err() != nil {
183+
return INTERRUPT
184+
}
185+
return c.exec(sql)
186+
}
187+
188+
func (c *Conn) exec(sql string) error {
187189
defer c.arena.mark()()
188190
textPtr := c.arena.string(sql)
189-
190-
c.checkInterrupt()
191191
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0))
192192
return c.error(rc, sql)
193193
}
@@ -207,13 +207,15 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
207207
if len(sql) > _MAX_SQL_LENGTH {
208208
return nil, "", TOOBIG
209209
}
210+
if c.interrupt.Err() != nil {
211+
return nil, "", INTERRUPT
212+
}
210213

211214
defer c.arena.mark()()
212215
stmtPtr := c.arena.new(ptrlen)
213216
tailPtr := c.arena.new(ptrlen)
214217
textPtr := c.arena.string(sql)
215218

216-
c.checkInterrupt()
217219
rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle),
218220
stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags),
219221
stk_t(stmtPtr), stk_t(tailPtr)))
@@ -343,42 +345,9 @@ func (c *Conn) GetInterrupt() context.Context {
343345
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
344346
old = c.interrupt
345347
c.interrupt = ctx
346-
347-
if ctx == old {
348-
return old
349-
}
350-
351-
// An active SQL statement prevents SQLite from ignoring an interrupt
352-
// that comes before any other statements are started.
353-
if c.pending == 0 {
354-
defer c.arena.mark()()
355-
stmtPtr := c.arena.new(ptrlen)
356-
textPtr := c.arena.string(`SELECT 0 UNION ALL SELECT 0`)
357-
c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(textPtr), math.MaxUint64,
358-
stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0)
359-
c.pending = util.Read32[ptr_t](c.mod, stmtPtr)
360-
}
361-
362-
if c.stepped && ctx.Err() == nil {
363-
c.call("sqlite3_reset", stk_t(c.pending))
364-
c.stepped = false
365-
} else {
366-
c.checkInterrupt()
367-
}
368348
return old
369349
}
370350

371-
func (c *Conn) checkInterrupt() {
372-
if c.interrupt.Err() == nil {
373-
return
374-
}
375-
if !c.stepped {
376-
c.call("sqlite3_step", stk_t(c.pending))
377-
c.stepped = true
378-
}
379-
c.call("sqlite3_interrupt", stk_t(c.handle))
380-
}
381-
382351
func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) {
383352
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
384353
if c.gosched++; c.gosched%16 == 0 {

driver/driver_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,62 @@ func Test_BeginTx(t *testing.T) {
199199
}
200200
}
201201

202+
func Test_nested_context(t *testing.T) {
203+
t.Parallel()
204+
tmp := memdb.TestDB(t)
205+
206+
db, err := sql.Open("sqlite3", tmp)
207+
if err != nil {
208+
t.Fatal(err)
209+
}
210+
defer db.Close()
211+
212+
tx, err := db.Begin()
213+
if err != nil {
214+
t.Fatal(err)
215+
}
216+
defer tx.Rollback()
217+
218+
outer, err := tx.Query(`SELECT value FROM generate_series(0)`)
219+
if err != nil {
220+
t.Fatal(err)
221+
}
222+
defer outer.Close()
223+
224+
want := func(rows *sql.Rows, want int) {
225+
t.Helper()
226+
227+
var got int
228+
rows.Next()
229+
if err := rows.Scan(&got); err != nil {
230+
t.Fatal(err)
231+
}
232+
if got != want {
233+
t.Errorf("got %d, want %d", got, want)
234+
}
235+
}
236+
237+
want(outer, 0)
238+
239+
ctx, cancel := context.WithCancel(context.Background())
240+
defer cancel()
241+
242+
inner, err := tx.QueryContext(ctx, `SELECT value FROM generate_series(0)`)
243+
if err != nil {
244+
t.Fatal(err)
245+
}
246+
defer inner.Close()
247+
248+
want(inner, 0)
249+
cancel()
250+
251+
if inner.Next() || !errors.Is(inner.Err(), sqlite3.INTERRUPT) {
252+
t.Fatal(inner.Err())
253+
}
254+
255+
want(outer, 1)
256+
}
257+
202258
func Test_Prepare(t *testing.T) {
203259
t.Parallel()
204260
tmp := memdb.TestDB(t)

stmt.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,11 @@ func (s *Stmt) Busy() bool {
106106
//
107107
// https://sqlite.org/c3ref/step.html
108108
func (s *Stmt) Step() bool {
109-
s.c.checkInterrupt()
109+
if s.c.interrupt.Err() != nil {
110+
s.err = INTERRUPT
111+
return false
112+
}
113+
110114
rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle)))
111115
switch rc {
112116
case _ROW:

txn.go

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package sqlite3
22

33
import (
44
"context"
5-
"errors"
65
"math/rand"
76
"runtime"
87
"strconv"
@@ -25,7 +24,7 @@ type Txn struct {
2524
// https://sqlite.org/lang_transaction.html
2625
func (c *Conn) Begin() Txn {
2726
// BEGIN even if interrupted.
28-
err := c.txnExecInterrupted(`BEGIN DEFERRED`)
27+
err := c.exec(`BEGIN DEFERRED`)
2928
if err != nil {
3029
panic(err)
3130
}
@@ -120,7 +119,7 @@ func (tx Txn) Commit() error {
120119
//
121120
// https://sqlite.org/lang_transaction.html
122121
func (tx Txn) Rollback() error {
123-
return tx.c.txnExecInterrupted(`ROLLBACK`)
122+
return tx.c.exec(`ROLLBACK`)
124123
}
125124

126125
// Savepoint is a marker within a transaction
@@ -143,7 +142,7 @@ func (c *Conn) Savepoint() Savepoint {
143142
// Names can be reused, but this makes catching bugs more likely.
144143
name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31())))
145144

146-
err := c.txnExecInterrupted(`SAVEPOINT ` + name)
145+
err := c.exec(`SAVEPOINT ` + name)
147146
if err != nil {
148147
panic(err)
149148
}
@@ -199,7 +198,7 @@ func (s Savepoint) Release(errp *error) {
199198
return
200199
}
201200
// ROLLBACK and RELEASE even if interrupted.
202-
err := s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
201+
err := s.c.exec(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
203202
if err != nil {
204203
panic(err)
205204
}
@@ -212,17 +211,7 @@ func (s Savepoint) Release(errp *error) {
212211
// https://sqlite.org/lang_transaction.html
213212
func (s Savepoint) Rollback() error {
214213
// ROLLBACK even if interrupted.
215-
return s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name)
216-
}
217-
218-
func (c *Conn) txnExecInterrupted(sql string) error {
219-
err := c.Exec(sql)
220-
if errors.Is(err, INTERRUPT) {
221-
old := c.SetInterrupt(context.Background())
222-
defer c.SetInterrupt(old)
223-
err = c.Exec(sql)
224-
}
225-
return err
214+
return s.c.exec(`ROLLBACK TO ` + s.name)
226215
}
227216

228217
// TxnState determines the transaction state of a database.

vtab.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,11 @@ func implements[T any](typ reflect.Type) bool {
7979
//
8080
// https://sqlite.org/c3ref/declare_vtab.html
8181
func (c *Conn) DeclareVTab(sql string) error {
82+
if c.interrupt.Err() != nil {
83+
return INTERRUPT
84+
}
8285
defer c.arena.mark()()
8386
textPtr := c.arena.string(sql)
84-
85-
c.checkInterrupt()
8687
rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(textPtr)))
8788
return c.error(rc)
8889
}

0 commit comments

Comments
 (0)