Skip to content

Commit 7e1a61d

Browse files
committed
Fix racy handling of context cancellation
[why] The context cancellation goroutine is not synchronized with the lifetime of the Next() method. This often results in sql.ErrNoRows being returned instead of context.Canceled (easy to reproduce). It can also interrupt the next query executed on the same connection (harder to reproduce). [how] Run the query in a goroutine and, on cancellation, wait until the interruption has completed. [testing] Add unit tests that reproduce the error cases.
1 parent d3c6909 commit 7e1a61d

File tree

2 files changed

+147
-37
lines changed

2 files changed

+147
-37
lines changed

sqlite3.go

+59-37
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ type SQLiteRows struct {
328328
decltype []string
329329
cls bool
330330
closed bool
331-
done chan struct{}
331+
ctx context.Context // no better alternative to pass context into Next() method
332332
}
333333

334334
type functionInfo struct {
@@ -1846,22 +1846,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
18461846
decltype: nil,
18471847
cls: s.cls,
18481848
closed: false,
1849-
done: make(chan struct{}),
1850-
}
1851-
1852-
if ctxdone := ctx.Done(); ctxdone != nil {
1853-
go func(db *C.sqlite3) {
1854-
select {
1855-
case <-ctxdone:
1856-
select {
1857-
case <-rows.done:
1858-
default:
1859-
C.sqlite3_interrupt(db)
1860-
rows.Close()
1861-
}
1862-
case <-rows.done:
1863-
}
1864-
}(s.c.db)
1849+
ctx: ctx,
18651850
}
18661851

18671852
return rows, nil
@@ -1889,29 +1874,43 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
18891874
return s.exec(context.Background(), list)
18901875
}
18911876

1877+
// exec executes a query that doesn't return rows. Attempts to honor context timeout.
18921878
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
1879+
if ctx.Done() == nil {
1880+
return s.execSync(args)
1881+
}
1882+
1883+
type result struct {
1884+
r driver.Result
1885+
err error
1886+
}
1887+
resultCh := make(chan result)
1888+
go func() {
1889+
r, err := s.execSync(args)
1890+
resultCh <- result{r, err}
1891+
}()
1892+
select {
1893+
case rv := <- resultCh:
1894+
return rv.r, rv.err
1895+
case <-ctx.Done():
1896+
select {
1897+
case <-resultCh: // no need to interrupt
1898+
default:
1899+
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
1900+
C.sqlite3_interrupt(s.c.db)
1901+
<-resultCh // ensure goroutine completed
1902+
}
1903+
return nil, ctx.Err()
1904+
}
1905+
}
1906+
1907+
func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
18931908
if err := s.bind(args); err != nil {
18941909
C.sqlite3_reset(s.s)
18951910
C.sqlite3_clear_bindings(s.s)
18961911
return nil, err
18971912
}
18981913

1899-
if ctxdone := ctx.Done(); ctxdone != nil {
1900-
done := make(chan struct{})
1901-
defer close(done)
1902-
go func(db *C.sqlite3) {
1903-
select {
1904-
case <-done:
1905-
case <-ctxdone:
1906-
select {
1907-
case <-done:
1908-
default:
1909-
C.sqlite3_interrupt(db)
1910-
}
1911-
}
1912-
}(s.c.db)
1913-
}
1914-
19151914
var rowid, changes C.longlong
19161915
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
19171916
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
@@ -1932,9 +1931,6 @@ func (rc *SQLiteRows) Close() error {
19321931
return nil
19331932
}
19341933
rc.closed = true
1935-
if rc.done != nil {
1936-
close(rc.done)
1937-
}
19381934
if rc.cls {
19391935
rc.s.mu.Unlock()
19401936
return rc.s.Close()
@@ -1978,13 +1974,39 @@ func (rc *SQLiteRows) DeclTypes() []string {
19781974
return rc.declTypes()
19791975
}
19801976

1981-
// Next move cursor to next.
1977+
// Next move cursor to next. Attempts to honor context timeout from QueryContext call.
19821978
func (rc *SQLiteRows) Next(dest []driver.Value) error {
19831979
rc.s.mu.Lock()
19841980
defer rc.s.mu.Unlock()
1981+
19851982
if rc.s.closed {
19861983
return io.EOF
19871984
}
1985+
1986+
if rc.ctx.Done() == nil {
1987+
return rc.nextSyncLocked(dest)
1988+
}
1989+
resultCh := make(chan error)
1990+
go func() {
1991+
resultCh <- rc.nextSyncLocked(dest)
1992+
}()
1993+
select {
1994+
case err := <- resultCh:
1995+
return err
1996+
case <-rc.ctx.Done():
1997+
select {
1998+
case <-resultCh: // no need to interrupt
1999+
default:
2000+
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
2001+
C.sqlite3_interrupt(rc.s.c.db)
2002+
<-resultCh // ensure goroutine completed
2003+
}
2004+
return rc.ctx.Err()
2005+
}
2006+
}
2007+
2008+
// nextSyncLocked moves cursor to next; must be called with locked mutex.
2009+
func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
19882010
rv := C._sqlite3_step_internal(rc.s.s)
19892011
if rv == C.SQLITE_DONE {
19902012
return io.EOF

sqlite3_go18_test.go

+88
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"io/ioutil"
1515
"math/rand"
1616
"os"
17+
"sync"
1718
"testing"
1819
"time"
1920
)
@@ -135,6 +136,93 @@ func TestShortTimeout(t *testing.T) {
135136
}
136137
}
137138

139+
func TestQueryRowContextCancel(t *testing.T) {
140+
srcTempFilename := TempFilename(t)
141+
defer os.Remove(srcTempFilename)
142+
143+
db, err := sql.Open("sqlite3", srcTempFilename)
144+
if err != nil {
145+
t.Fatal(err)
146+
}
147+
defer db.Close()
148+
initDatabase(t, db, 100)
149+
150+
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
151+
var keyID string
152+
unexpectedErrors := make(map[string]int)
153+
for i := 0; i < 10000; i++ {
154+
ctx, cancel := context.WithCancel(context.Background())
155+
row := db.QueryRowContext(ctx, query)
156+
157+
cancel()
158+
// it is fine to get "nil" as context cancellation can be handled with delay
159+
if err := row.Scan(&keyID); err != nil && err != context.Canceled {
160+
if err.Error() == "sql: Rows are closed" {
161+
// see https://github.com/golang/go/issues/24431
162+
// fixed in 1.11.1 to properly return context error
163+
continue
164+
}
165+
unexpectedErrors[err.Error()]++
166+
}
167+
}
168+
for errText, count := range unexpectedErrors {
169+
t.Error(errText, count)
170+
}
171+
}
172+
173+
func TestQueryRowContextCancelParallel(t *testing.T) {
174+
srcTempFilename := TempFilename(t)
175+
defer os.Remove(srcTempFilename)
176+
177+
db, err := sql.Open("sqlite3", srcTempFilename)
178+
if err != nil {
179+
t.Fatal(err)
180+
}
181+
db.SetMaxOpenConns(10)
182+
db.SetMaxIdleConns(5)
183+
184+
defer db.Close()
185+
initDatabase(t, db, 100)
186+
187+
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
188+
wg := sync.WaitGroup{}
189+
defer wg.Wait()
190+
191+
testCtx, cancel := context.WithCancel(context.Background())
192+
defer cancel()
193+
194+
for i := 0; i < 50; i++ {
195+
wg.Add(1)
196+
go func() {
197+
defer wg.Done()
198+
199+
var keyID string
200+
for {
201+
select {
202+
case <-testCtx.Done():
203+
return
204+
default:
205+
}
206+
ctx, cancel := context.WithCancel(context.Background())
207+
row := db.QueryRowContext(ctx, query)
208+
209+
cancel()
210+
_ = row.Scan(&keyID) // see TestQueryRowContextCancel
211+
}
212+
}()
213+
}
214+
215+
var keyID string
216+
for i := 0; i < 10000; i++ {
217+
// note that testCtx is not cancelled during query execution
218+
row := db.QueryRowContext(testCtx, query)
219+
220+
if err := row.Scan(&keyID); err != nil {
221+
t.Fatal(i, err)
222+
}
223+
}
224+
}
225+
138226
func TestExecCancel(t *testing.T) {
139227
db, err := sql.Open("sqlite3", ":memory:")
140228
if err != nil {

0 commit comments

Comments
 (0)