Skip to content

Commit 590d44c

Browse files
authored
Merge pull request #744 from azavorotnii/ctx_cancel
Fix context cancellation racy handling
2 parents 0cf797e + 7e1a61d commit 590d44c

File tree

2 files changed

+147
-37
lines changed

2 files changed

+147
-37
lines changed

sqlite3.go

+59-37
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ type SQLiteRows struct {
328328
decltype []string
329329
cls bool
330330
closed bool
331-
done chan struct{}
331+
ctx context.Context // no better alternative to pass context into Next() method
332332
}
333333

334334
type functionInfo struct {
@@ -1847,22 +1847,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
18471847
decltype: nil,
18481848
cls: s.cls,
18491849
closed: false,
1850-
done: make(chan struct{}),
1851-
}
1852-
1853-
if ctxdone := ctx.Done(); ctxdone != nil {
1854-
go func(db *C.sqlite3) {
1855-
select {
1856-
case <-ctxdone:
1857-
select {
1858-
case <-rows.done:
1859-
default:
1860-
C.sqlite3_interrupt(db)
1861-
rows.Close()
1862-
}
1863-
case <-rows.done:
1864-
}
1865-
}(s.c.db)
1850+
ctx: ctx,
18661851
}
18671852

18681853
return rows, nil
@@ -1890,29 +1875,43 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
18901875
return s.exec(context.Background(), list)
18911876
}
18921877

1878+
// exec executes a query that doesn't return rows. Attempts to honor context timeout.
18931879
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
1880+
if ctx.Done() == nil {
1881+
return s.execSync(args)
1882+
}
1883+
1884+
type result struct {
1885+
r driver.Result
1886+
err error
1887+
}
1888+
resultCh := make(chan result)
1889+
go func() {
1890+
r, err := s.execSync(args)
1891+
resultCh <- result{r, err}
1892+
}()
1893+
select {
1894+
case rv := <- resultCh:
1895+
return rv.r, rv.err
1896+
case <-ctx.Done():
1897+
select {
1898+
case <-resultCh: // no need to interrupt
1899+
default:
1900+
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
1901+
C.sqlite3_interrupt(s.c.db)
1902+
<-resultCh // ensure goroutine completed
1903+
}
1904+
return nil, ctx.Err()
1905+
}
1906+
}
1907+
1908+
func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
18941909
if err := s.bind(args); err != nil {
18951910
C.sqlite3_reset(s.s)
18961911
C.sqlite3_clear_bindings(s.s)
18971912
return nil, err
18981913
}
18991914

1900-
if ctxdone := ctx.Done(); ctxdone != nil {
1901-
done := make(chan struct{})
1902-
defer close(done)
1903-
go func(db *C.sqlite3) {
1904-
select {
1905-
case <-done:
1906-
case <-ctxdone:
1907-
select {
1908-
case <-done:
1909-
default:
1910-
C.sqlite3_interrupt(db)
1911-
}
1912-
}
1913-
}(s.c.db)
1914-
}
1915-
19161915
var rowid, changes C.longlong
19171916
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
19181917
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
@@ -1933,9 +1932,6 @@ func (rc *SQLiteRows) Close() error {
19331932
return nil
19341933
}
19351934
rc.closed = true
1936-
if rc.done != nil {
1937-
close(rc.done)
1938-
}
19391935
if rc.cls {
19401936
rc.s.mu.Unlock()
19411937
return rc.s.Close()
@@ -1979,13 +1975,39 @@ func (rc *SQLiteRows) DeclTypes() []string {
19791975
return rc.declTypes()
19801976
}
19811977

1982-
// Next move cursor to next.
1978+
// Next move cursor to next. Attempts to honor context timeout from QueryContext call.
19831979
func (rc *SQLiteRows) Next(dest []driver.Value) error {
19841980
rc.s.mu.Lock()
19851981
defer rc.s.mu.Unlock()
1982+
19861983
if rc.s.closed {
19871984
return io.EOF
19881985
}
1986+
1987+
if rc.ctx.Done() == nil {
1988+
return rc.nextSyncLocked(dest)
1989+
}
1990+
resultCh := make(chan error)
1991+
go func() {
1992+
resultCh <- rc.nextSyncLocked(dest)
1993+
}()
1994+
select {
1995+
case err := <- resultCh:
1996+
return err
1997+
case <-rc.ctx.Done():
1998+
select {
1999+
case <-resultCh: // no need to interrupt
2000+
default:
2001+
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
2002+
C.sqlite3_interrupt(rc.s.c.db)
2003+
<-resultCh // ensure goroutine completed
2004+
}
2005+
return rc.ctx.Err()
2006+
}
2007+
}
2008+
2009+
// nextSyncLocked moves cursor to next; must be called with locked mutex.
2010+
func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
19892011
rv := C._sqlite3_step_internal(rc.s.s)
19902012
if rv == C.SQLITE_DONE {
19912013
return io.EOF

sqlite3_go18_test.go

+88
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"io/ioutil"
1515
"math/rand"
1616
"os"
17+
"sync"
1718
"testing"
1819
"time"
1920
)
@@ -135,6 +136,93 @@ func TestShortTimeout(t *testing.T) {
135136
}
136137
}
137138

139+
func TestQueryRowContextCancel(t *testing.T) {
140+
srcTempFilename := TempFilename(t)
141+
defer os.Remove(srcTempFilename)
142+
143+
db, err := sql.Open("sqlite3", srcTempFilename)
144+
if err != nil {
145+
t.Fatal(err)
146+
}
147+
defer db.Close()
148+
initDatabase(t, db, 100)
149+
150+
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
151+
var keyID string
152+
unexpectedErrors := make(map[string]int)
153+
for i := 0; i < 10000; i++ {
154+
ctx, cancel := context.WithCancel(context.Background())
155+
row := db.QueryRowContext(ctx, query)
156+
157+
cancel()
158+
// it is fine to get "nil" as context cancellation can be handled with delay
159+
if err := row.Scan(&keyID); err != nil && err != context.Canceled {
160+
if err.Error() == "sql: Rows are closed" {
161+
// see https://github.com/golang/go/issues/24431
162+
// fixed in 1.11.1 to properly return context error
163+
continue
164+
}
165+
unexpectedErrors[err.Error()]++
166+
}
167+
}
168+
for errText, count := range unexpectedErrors {
169+
t.Error(errText, count)
170+
}
171+
}
172+
173+
func TestQueryRowContextCancelParallel(t *testing.T) {
174+
srcTempFilename := TempFilename(t)
175+
defer os.Remove(srcTempFilename)
176+
177+
db, err := sql.Open("sqlite3", srcTempFilename)
178+
if err != nil {
179+
t.Fatal(err)
180+
}
181+
db.SetMaxOpenConns(10)
182+
db.SetMaxIdleConns(5)
183+
184+
defer db.Close()
185+
initDatabase(t, db, 100)
186+
187+
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
188+
wg := sync.WaitGroup{}
189+
defer wg.Wait()
190+
191+
testCtx, cancel := context.WithCancel(context.Background())
192+
defer cancel()
193+
194+
for i := 0; i < 50; i++ {
195+
wg.Add(1)
196+
go func() {
197+
defer wg.Done()
198+
199+
var keyID string
200+
for {
201+
select {
202+
case <-testCtx.Done():
203+
return
204+
default:
205+
}
206+
ctx, cancel := context.WithCancel(context.Background())
207+
row := db.QueryRowContext(ctx, query)
208+
209+
cancel()
210+
_ = row.Scan(&keyID) // see TestQueryRowContextCancel
211+
}
212+
}()
213+
}
214+
215+
var keyID string
216+
for i := 0; i < 10000; i++ {
217+
// note that testCtx is not cancelled during query execution
218+
row := db.QueryRowContext(testCtx, query)
219+
220+
if err := row.Scan(&keyID); err != nil {
221+
t.Fatal(i, err)
222+
}
223+
}
224+
}
225+
138226
func TestExecCancel(t *testing.T) {
139227
db, err := sql.Open("sqlite3", ":memory:")
140228
if err != nil {

0 commit comments

Comments
 (0)