Skip to content

Commit

Permalink
return errors in favor of panicking, when possible
Browse files Browse the repository at this point in the history
    This commit changes copyist to instead return *sessionError via the
    sql driver in favor of panicking. Note in some cases, namely
    proxyRows.Columns, copyist will still panic as the proper behavior
    is otherwise unclear. Most usecases will encounter an error either
    before or after calling .Columns, so it is unlikely for users to
    encounter this panic.

    Returning errors instead of panicking fixes deadlocks that can occur
    when a driver.Conn is returning a driver.Rows inside of a
    transaction that recovers panics. Namely, cockroach-go [1]. A panic
    within a SQL Driver will cause the connections read mutex to NOT be
    unlocked as control should have been passed down to the sql.Rows
    (see queryDC in sql/sql.go). When the transaction helper method
    attempts to execute tx.Rollback, it will deadlock upon attempting to
    acquire the mutex.

    ```
	// Cancel the Tx to release any active R-closemu locks.
	// This is safe to do because tx.done has already transitioned
	// from 0 to 1. Hold the W-closemu lock prior to rollback
	// to ensure no other connection has an active query.
	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()
    ```

    Note that similar behavior would be encountered if instead
    attempting to call .Close on the Conn object.

    Users should be able to make use of transaction helpers that recover
    panics as behavior may become unpredictable within system that
    recover the panic elsewhere such as an HTTP server. Repositories
    under the cockroachdb organization should also work well together.

    [1] https://github.com/cockroachdb/cockroach-go/blob/master/crdb/common.go#L40
  • Loading branch information
chrisseto committed Aug 23, 2022
1 parent eacfd65 commit 5b4bb61
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 50 deletions.
28 changes: 20 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ func (c *proxyConn) ExecContext(
return &proxyResult{res: res}, nil
}

rec := currentSession.VerifyRecordWithStringArg(ConnExec, query)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecordWithStringArg(ConnExec, query)
if err != nil {
return nil, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -130,8 +133,11 @@ func (c *proxyConn) PrepareContext(ctx context.Context, query string) (driver.St
return &proxyStmt{stmt: stmt}, nil
}

rec := currentSession.VerifyRecordWithStringArg(ConnPrepare, query)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecordWithStringArg(ConnPrepare, query)
if err != nil {
return nil, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -169,8 +175,11 @@ func (c *proxyConn) QueryContext(
return &proxyRows{rows: rows}, nil
}

rec := currentSession.VerifyRecordWithStringArg(ConnQuery, query)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecordWithStringArg(ConnQuery, query)
if err != nil {
return nil, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -232,8 +241,11 @@ func (c *proxyConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.
return &proxyTx{tx: tx}, nil
}

rec := currentSession.VerifyRecord(ConnBegin)
err, _ := rec.Args[0].(error)
rec, err := currentSession.VerifyRecord(ConnBegin)
if err != nil {
return nil, err
}
err, _ = rec.Args[0].(error)
if err != nil {
return nil, err
}
Expand Down
26 changes: 15 additions & 11 deletions copyist.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ var registered map[string]*proxyDriver
// The Register method takes the name of the SQL driver to be wrapped (e.g.
// "postgres"). Below is an example of how copyist.Register should be invoked.
//
// copyist.Register("postgres")
// copyist.Register("postgres")
//
// Note that Register can only be called once for a given driver; subsequent
// attempts will fail with an error. In addition, the same copyist driver must
Expand Down Expand Up @@ -139,22 +139,22 @@ func SetSessionInit(callback SessionInitCallback) {
// alongside the calling test file. If playing back, then the recording will
// be fetched from that recording file. Here is a typical calling pattern:
//
// func init() {
// copyist.Register("postgres")
// }
// func init() {
// copyist.Register("postgres")
// }
//
// func TestMyStuff(t *testing.T) {
// defer copyist.Open(t).Close()
// ...
// }
// func TestMyStuff(t *testing.T) {
// defer copyist.Open(t).Close()
// ...
// }
//
// The call to Open will initiate a new recording session. The deferred call to
// Close will complete the recording session and write the recording to a file
// in the testdata/ directory, like:
//
// mystuff_test.go
// testdata/
// mystuff_test.copyist
// mystuff_test.go
// testdata/
// mystuff_test.copyist
//
// Each test or sub-test that needs to be executed independently needs to record
// its own session.
Expand Down Expand Up @@ -215,6 +215,10 @@ func OpenSource(t testingT, source Source, recordingName string) io.Closer {
panic(r)
}

if currentSession.verificationErr != nil {
t.Fatalf("%+v\n", currentSession.verificationErr.error)
}

currentSession.Close()
currentSession = nil
return nil
Expand Down
8 changes: 4 additions & 4 deletions copyist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (t *mockTestingT) Fatalf(format string, args ...interface{}) {
fmt.Fprintf(&t.buf, format, args...)
}

func TestSessionPanicsAreCaught(t *testing.T) {
func TestSessionFailuresAreFatalfd(t *testing.T) {
// Enter playback mode.
*recordFlag = false
visitedRecording = true
Expand All @@ -102,16 +102,16 @@ func TestSessionPanicsAreCaught(t *testing.T) {

m := &mockTestingT{T: t}
defer func() {
require.Equal(t, "no recording exists with this name: TestSessionPanicsAreCaught\n",
require.Equal(t, "no recording exists with this name: TestSessionFailuresAreFatalfd\n",
m.buf.String())
}()

defer Open(m).Close()

db, err := sql.Open("copyist_postgres2", "")
require.NoError(t, err)
// NB: This will panic, but the panic will be caught by the copyist closer and
// converted into a call to testing.T.Fatalf.
// NB: This will return an error, but the the copyist closer will track the
// first error and convert it into a call to testing.T.Fatalf.
db.Query("SELECT 1")
}

Expand Down
7 changes: 5 additions & 2 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,11 @@ func (d *proxyDriver) Open(name string) (driver.Conn, error) {
return &proxyConn{driver: d, conn: conn, name: name, session: currentSession}, nil
}

rec := currentSession.VerifyRecord(DriverOpen)
err, _ := rec.Args[0].(error)
rec, err := currentSession.VerifyRecord(DriverOpen)
if err != nil {
return nil, err
}
err, _ = rec.Args[0].(error)
if err != nil {
return nil, err
}
Expand Down
88 changes: 88 additions & 0 deletions drivertest/commontest/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package commontest_test
import (
"bytes"
"database/sql"
"fmt"
"io"
"testing"

Expand Down Expand Up @@ -94,6 +95,84 @@ func TestOpenReadWriteCloser(t *testing.T) {
rows.Next()
}

func TestRollbackWithRecover(t *testing.T) {
// This bug is only present in playback mode, short circuit if we're
// recording.
if copyist.IsRecording() {
return
}

// This is a regression test for a deadlock when copyist would panic upon
// recording failures. We mount an intentionally out of date source that
// will fail on any action after opening our transaction. Our transaction
// helper will attempt a rollback in the case of an error or a panic, which
// would catch copyist's old behavior of panicking upon out of date
// recordings.
// We assert that we hit an out of date error and that rollback is called
// and returns.
source := CopyistSource(bytes.NewBuffer([]byte(`
1=DriverOpen 1:nil
2=ConnBegin 1:nil
"TestRollbackWithRecover"=1,2`)))

defer leaktest.Check(t)()

var mt mockT

closer := copyist.OpenSource(&mt, source, t.Name())

// Open database.
db, err := sql.Open("copyist_postgres", dataSourceName)
require.NoError(t, err)
defer db.Close()

fnErr, txErr := execTransaction(db, func(tx *sql.Tx) error {
_, err := tx.Query("SELECT 1")
return err
})

require.EqualError(t, fnErr, "too many calls to ConnQuery\n\nDo you need to regenerate the recording with the -record flag?")
require.EqualError(t, txErr, "too many calls to TxRollback\n\nDo you need to regenerate the recording with the -record flag?")

require.NoError(t, closer.Close()) // closer never errors.

// Verify that the call to .Close invokes t.Fatalf with the first session
// error that we encountered.
require.Contains(t, mt.failure, "too many calls to ConnQuery")
// Verify that t.Fatalf includes the stacktrace leading to the call that
// triggered the first error. In this case, we look for the error coming
// from the first closure defined within this test function.
require.Contains(t, mt.failure, fmt.Sprintf("commontest_test.%s.func1", t.Name()))
}

// execTransaction is a transaction helper function that attempts a rollback in
// the case of panics of errors. It returns both the closure error and the
// error of either commiting or rolling back.
// It is intended to mimic the behavior of
// https://github.com/cockroachdb/cockroach-go/blob/21a237074d6c3c35b68ec43e8d0c9e9ed714d21a/crdb/common.go#L38
func execTransaction(db *sql.DB, fn func(*sql.Tx) error) (fnErr error, txErr error) {
tx, err := db.Begin()
if err != nil {
return nil, err
}

defer func() {
if r := recover(); r != nil {
txErr = tx.Rollback()
panic(r)
}

if fnErr == nil {
txErr = tx.Commit()
} else {
txErr = tx.Rollback()
}
}()

return fn(tx), nil
}

func TestIsOpen(t *testing.T) {
require.False(t, copyist.IsOpen())

Expand Down Expand Up @@ -164,3 +243,12 @@ func (s copyistSource) WriteAll([]byte) error {
func CopyistSource(r io.Reader) copyist.Source {
return copyistSource{r}
}

type mockT struct {
failure string
}

func (mockT) Name() string { return "" }
func (t *mockT) Fatalf(format string, args ...interface{}) {
t.failure = fmt.Sprintf(format, args...)
}
1 change: 1 addition & 0 deletions drivertest/pqtestold/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ require (
github.com/jackc/pgx/v4 v4.13.0
github.com/jmoiron/sqlx v1.3.4
github.com/lib/pq v1.10.2
github.com/pkg/errors v0.8.1
github.com/stretchr/testify v1.7.0
)
14 changes: 10 additions & 4 deletions result.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ func (r *proxyResult) LastInsertId() (int64, error) {
return id, err
}

rec := currentSession.VerifyRecord(ResultLastInsertId)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecord(ResultLastInsertId)
if err != nil {
return 0, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return 0, err
}
Expand All @@ -51,8 +54,11 @@ func (r *proxyResult) RowsAffected() (int64, error) {
return affected, err
}

rec := currentSession.VerifyRecord(ResultRowsAffected)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecord(ResultRowsAffected)
if err != nil {
return 0, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return 0, err
}
Expand Down
12 changes: 9 additions & 3 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ func (r *proxyRows) Columns() []string {
return cols
}

rec := currentSession.VerifyRecord(RowsColumns)
rec, err := currentSession.VerifyRecord(RowsColumns)
if err != nil {
panic(err)
}
return rec.Args[0].([]string)
}

Expand Down Expand Up @@ -70,8 +73,11 @@ func (r *proxyRows) Next(dest []driver.Value) error {
return err
}

rec := currentSession.VerifyRecord(RowsNext)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecord(RowsNext)
if err != nil {
return err
}
err, _ = rec.Args[1].(error)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 5b4bb61

Please sign in to comment.