From 5b4bb61c61700d26788aafba90167ff5c8b3e644 Mon Sep 17 00:00:00 2001 From: Chris Seto Date: Tue, 23 Aug 2022 14:05:57 +0000 Subject: [PATCH] return errors in favor of panicking, when possible This commit changes copyist to instead return *sessionError via the sql driver in favor of panicking. Note in some cases, namely proxyRows.Columns, copyist will still panic as the proper behavior is otherwise unclear. Most usecases will encounter an error either before or after calling .Columns, so it is unlikely for users to encounter this panic. Returning errors instead of panicking fixes deadlocks that can occur when a driver.Conn is returning a driver.Rows inside of a transaction that recovers panics. Namely, cockroach-go [1]. A panic within a SQL Driver will cause the connections read mutex to NOT be unlocked as control should have been passed down to the sql.Rows (see queryDC in sql/sql.go). When the transaction helper method attempts to execute tx.Rollback, it will deadlock upon attempting to acquire the mutex. ``` // Cancel the Tx to release any active R-closemu locks. // This is safe to do because tx.done has already transitioned // from 0 to 1. Hold the W-closemu lock prior to rollback // to ensure no other connection has an active query. tx.cancel() tx.closemu.Lock() tx.closemu.Unlock() ``` Note that similar behavior would be encountered if instead attempting to call .Close on the Conn object. Users should be able to make use of transaction helpers that recover panics as behavior may become unpredictable within system that recover the panic elsewhere such as an HTTP server. Repositories under the cockroachdb organization should also work well together. [1] https://github.com/cockroachdb/cockroach-go/blob/master/crdb/common.go#L40 --- conn.go | 28 ++++++--- copyist.go | 26 ++++---- copyist_test.go | 8 +-- driver.go | 7 ++- drivertest/commontest/common_test.go | 88 ++++++++++++++++++++++++++++ drivertest/pqtestold/go.sum | 1 + go.mod | 1 + result.go | 14 +++-- rows.go | 12 +++- session.go | 35 ++++++++--- stmt.go | 19 ++++-- tx.go | 14 +++-- 12 files changed, 203 insertions(+), 50 deletions(-) diff --git a/conn.go b/conn.go index 984c2c8..239024e 100644 --- a/conn.go +++ b/conn.go @@ -97,8 +97,11 @@ func (c *proxyConn) ExecContext( return &proxyResult{res: res}, nil } - rec := currentSession.VerifyRecordWithStringArg(ConnExec, query) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecordWithStringArg(ConnExec, query) + if err != nil { + return nil, err + } + err, _ = rec.Args[1].(error) if err != nil { return nil, err } @@ -130,8 +133,11 @@ func (c *proxyConn) PrepareContext(ctx context.Context, query string) (driver.St return &proxyStmt{stmt: stmt}, nil } - rec := currentSession.VerifyRecordWithStringArg(ConnPrepare, query) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecordWithStringArg(ConnPrepare, query) + if err != nil { + return nil, err + } + err, _ = rec.Args[1].(error) if err != nil { return nil, err } @@ -169,8 +175,11 @@ func (c *proxyConn) QueryContext( return &proxyRows{rows: rows}, nil } - rec := currentSession.VerifyRecordWithStringArg(ConnQuery, query) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecordWithStringArg(ConnQuery, query) + if err != nil { + return nil, err + } + err, _ = rec.Args[1].(error) if err != nil { return nil, err } @@ -232,8 +241,11 @@ func (c *proxyConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver. return &proxyTx{tx: tx}, nil } - rec := currentSession.VerifyRecord(ConnBegin) - err, _ := rec.Args[0].(error) + rec, err := currentSession.VerifyRecord(ConnBegin) + if err != nil { + return nil, err + } + err, _ = rec.Args[0].(error) if err != nil { return nil, err } diff --git a/copyist.go b/copyist.go index 00098ed..73b0aa3 100644 --- a/copyist.go +++ b/copyist.go @@ -94,7 +94,7 @@ var registered map[string]*proxyDriver // The Register method takes the name of the SQL driver to be wrapped (e.g. // "postgres"). Below is an example of how copyist.Register should be invoked. // -// copyist.Register("postgres") +// copyist.Register("postgres") // // Note that Register can only be called once for a given driver; subsequent // attempts will fail with an error. In addition, the same copyist driver must @@ -139,22 +139,22 @@ func SetSessionInit(callback SessionInitCallback) { // alongside the calling test file. If playing back, then the recording will // be fetched from that recording file. Here is a typical calling pattern: // -// func init() { -// copyist.Register("postgres") -// } +// func init() { +// copyist.Register("postgres") +// } // -// func TestMyStuff(t *testing.T) { -// defer copyist.Open(t).Close() -// ... -// } +// func TestMyStuff(t *testing.T) { +// defer copyist.Open(t).Close() +// ... +// } // // The call to Open will initiate a new recording session. The deferred call to // Close will complete the recording session and write the recording to a file // in the testdata/ directory, like: // -// mystuff_test.go -// testdata/ -// mystuff_test.copyist +// mystuff_test.go +// testdata/ +// mystuff_test.copyist // // Each test or sub-test that needs to be executed independently needs to record // its own session. @@ -215,6 +215,10 @@ func OpenSource(t testingT, source Source, recordingName string) io.Closer { panic(r) } + if currentSession.verificationErr != nil { + t.Fatalf("%+v\n", currentSession.verificationErr.error) + } + currentSession.Close() currentSession = nil return nil diff --git a/copyist_test.go b/copyist_test.go index 38bb0a4..0e1c020 100644 --- a/copyist_test.go +++ b/copyist_test.go @@ -92,7 +92,7 @@ func (t *mockTestingT) Fatalf(format string, args ...interface{}) { fmt.Fprintf(&t.buf, format, args...) } -func TestSessionPanicsAreCaught(t *testing.T) { +func TestSessionFailuresAreFatalfd(t *testing.T) { // Enter playback mode. *recordFlag = false visitedRecording = true @@ -102,7 +102,7 @@ func TestSessionPanicsAreCaught(t *testing.T) { m := &mockTestingT{T: t} defer func() { - require.Equal(t, "no recording exists with this name: TestSessionPanicsAreCaught\n", + require.Equal(t, "no recording exists with this name: TestSessionFailuresAreFatalfd\n", m.buf.String()) }() @@ -110,8 +110,8 @@ func TestSessionPanicsAreCaught(t *testing.T) { db, err := sql.Open("copyist_postgres2", "") require.NoError(t, err) - // NB: This will panic, but the panic will be caught by the copyist closer and - // converted into a call to testing.T.Fatalf. + // NB: This will return an error, but the the copyist closer will track the + // first error and convert it into a call to testing.T.Fatalf. db.Query("SELECT 1") } diff --git a/driver.go b/driver.go index 013c8b7..6e647b8 100644 --- a/driver.go +++ b/driver.go @@ -125,8 +125,11 @@ func (d *proxyDriver) Open(name string) (driver.Conn, error) { return &proxyConn{driver: d, conn: conn, name: name, session: currentSession}, nil } - rec := currentSession.VerifyRecord(DriverOpen) - err, _ := rec.Args[0].(error) + rec, err := currentSession.VerifyRecord(DriverOpen) + if err != nil { + return nil, err + } + err, _ = rec.Args[0].(error) if err != nil { return nil, err } diff --git a/drivertest/commontest/common_test.go b/drivertest/commontest/common_test.go index d9fce46..f29e889 100644 --- a/drivertest/commontest/common_test.go +++ b/drivertest/commontest/common_test.go @@ -17,6 +17,7 @@ package commontest_test import ( "bytes" "database/sql" + "fmt" "io" "testing" @@ -94,6 +95,84 @@ func TestOpenReadWriteCloser(t *testing.T) { rows.Next() } +func TestRollbackWithRecover(t *testing.T) { + // This bug is only present in playback mode, short circuit if we're + // recording. + if copyist.IsRecording() { + return + } + + // This is a regression test for a deadlock when copyist would panic upon + // recording failures. We mount an intentionally out of date source that + // will fail on any action after opening our transaction. Our transaction + // helper will attempt a rollback in the case of an error or a panic, which + // would catch copyist's old behavior of panicking upon out of date + // recordings. + // We assert that we hit an out of date error and that rollback is called + // and returns. + source := CopyistSource(bytes.NewBuffer([]byte(` +1=DriverOpen 1:nil +2=ConnBegin 1:nil + +"TestRollbackWithRecover"=1,2`))) + + defer leaktest.Check(t)() + + var mt mockT + + closer := copyist.OpenSource(&mt, source, t.Name()) + + // Open database. + db, err := sql.Open("copyist_postgres", dataSourceName) + require.NoError(t, err) + defer db.Close() + + fnErr, txErr := execTransaction(db, func(tx *sql.Tx) error { + _, err := tx.Query("SELECT 1") + return err + }) + + require.EqualError(t, fnErr, "too many calls to ConnQuery\n\nDo you need to regenerate the recording with the -record flag?") + require.EqualError(t, txErr, "too many calls to TxRollback\n\nDo you need to regenerate the recording with the -record flag?") + + require.NoError(t, closer.Close()) // closer never errors. + + // Verify that the call to .Close invokes t.Fatalf with the first session + // error that we encountered. + require.Contains(t, mt.failure, "too many calls to ConnQuery") + // Verify that t.Fatalf includes the stacktrace leading to the call that + // triggered the first error. In this case, we look for the error coming + // from the first closure defined within this test function. + require.Contains(t, mt.failure, fmt.Sprintf("commontest_test.%s.func1", t.Name())) +} + +// execTransaction is a transaction helper function that attempts a rollback in +// the case of panics of errors. It returns both the closure error and the +// error of either commiting or rolling back. +// It is intended to mimic the behavior of +// https://github.com/cockroachdb/cockroach-go/blob/21a237074d6c3c35b68ec43e8d0c9e9ed714d21a/crdb/common.go#L38 +func execTransaction(db *sql.DB, fn func(*sql.Tx) error) (fnErr error, txErr error) { + tx, err := db.Begin() + if err != nil { + return nil, err + } + + defer func() { + if r := recover(); r != nil { + txErr = tx.Rollback() + panic(r) + } + + if fnErr == nil { + txErr = tx.Commit() + } else { + txErr = tx.Rollback() + } + }() + + return fn(tx), nil +} + func TestIsOpen(t *testing.T) { require.False(t, copyist.IsOpen()) @@ -164,3 +243,12 @@ func (s copyistSource) WriteAll([]byte) error { func CopyistSource(r io.Reader) copyist.Source { return copyistSource{r} } + +type mockT struct { + failure string +} + +func (mockT) Name() string { return "" } +func (t *mockT) Fatalf(format string, args ...interface{}) { + t.failure = fmt.Sprintf(format, args...) +} diff --git a/drivertest/pqtestold/go.sum b/drivertest/pqtestold/go.sum index b210669..3a090f9 100644 --- a/drivertest/pqtestold/go.sum +++ b/drivertest/pqtestold/go.sum @@ -81,6 +81,7 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/go.mod b/go.mod index ec39061..52bafe7 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,6 @@ require ( github.com/jackc/pgx/v4 v4.13.0 github.com/jmoiron/sqlx v1.3.4 github.com/lib/pq v1.10.2 + github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.7.0 ) diff --git a/result.go b/result.go index dbbfc58..557f6bc 100644 --- a/result.go +++ b/result.go @@ -34,8 +34,11 @@ func (r *proxyResult) LastInsertId() (int64, error) { return id, err } - rec := currentSession.VerifyRecord(ResultLastInsertId) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecord(ResultLastInsertId) + if err != nil { + return 0, err + } + err, _ = rec.Args[1].(error) if err != nil { return 0, err } @@ -51,8 +54,11 @@ func (r *proxyResult) RowsAffected() (int64, error) { return affected, err } - rec := currentSession.VerifyRecord(ResultRowsAffected) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecord(ResultRowsAffected) + if err != nil { + return 0, err + } + err, _ = rec.Args[1].(error) if err != nil { return 0, err } diff --git a/rows.go b/rows.go index d56aa84..d9bf3fe 100644 --- a/rows.go +++ b/rows.go @@ -35,7 +35,10 @@ func (r *proxyRows) Columns() []string { return cols } - rec := currentSession.VerifyRecord(RowsColumns) + rec, err := currentSession.VerifyRecord(RowsColumns) + if err != nil { + panic(err) + } return rec.Args[0].([]string) } @@ -70,8 +73,11 @@ func (r *proxyRows) Next(dest []driver.Value) error { return err } - rec := currentSession.VerifyRecord(RowsNext) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecord(RowsNext) + if err != nil { + return err + } + err, _ = rec.Args[1].(error) if err != nil { return err } diff --git a/session.go b/session.go index 58c3503..08824b0 100644 --- a/session.go +++ b/session.go @@ -17,6 +17,8 @@ package copyist import ( "fmt" "os" + + "github.com/pkg/errors" ) // session is state used during copyist recording and playback to track progress @@ -39,6 +41,10 @@ type session struct { // isInit is set to true once this session has been initialized. isInit bool + + // verificationErr is the first sessionError encountered when replaying + // this session for better error reporting later on. + verificationErr *sessionError } // currentSession is a global instance of session that tracks state for the @@ -107,33 +113,36 @@ func (s *session) AddRecord(rec *record) { // VerifyRecordWithStringArg returns one of the records in this session's // recording, failing with a nice error if no such record exists, or if its // first argument does not match the given string. -func (s *session) VerifyRecordWithStringArg(recordTyp recordType, arg string) *record { - rec := s.VerifyRecord(recordTyp) +func (s *session) VerifyRecordWithStringArg(recordTyp recordType, arg string) (*record, error) { + rec, err := s.VerifyRecord(recordTyp) + if err != nil { + return nil, err + } if rec.Args[0].(string) != arg { - panicf( + return nil, s.sessionErr( "mismatched argument to %s, expected %s, got %s\n\n"+ "Do you need to regenerate the recording with the -record flag?", recordTyp.String(), rec.Args[0].(string), arg) } - return rec + return rec, nil } // VerifyRecord returns one of the records in this session's recording, failing // with a nice error if no such record exists. -func (s *session) VerifyRecord(recordTyp recordType) *record { +func (s *session) VerifyRecord(recordTyp recordType) (*record, error) { if s.index >= len(s.recording) { - panicf( + return nil, s.sessionErr( "too many calls to %s\n\n"+ "Do you need to regenerate the recording with the -record flag?", recordTyp.String()) } rec := s.recording[s.index] if rec.Typ != recordTyp { - panicf( + return nil, s.sessionErr( "unexpected call to %s\n\n"+ "Do you need to regenerate the recording with the -record flag?", recordTyp.String()) } s.index++ - return rec + return rec, nil } // Close ends this session, writing any recording file and clearing state. @@ -157,7 +166,15 @@ func (s *session) Close() { clearPooledConnections() } -func panicf(format string, args ...interface{}) { +func (s *session) sessionErr(format string, args ...interface{}) error { + err := &sessionError{errors.Errorf(format, args...)} + if s.verificationErr == nil { + s.verificationErr = err + } + return err +} + +func panicf(format string, args ...interface{}) error { panic(&sessionError{fmt.Errorf(format, args...)}) } diff --git a/stmt.go b/stmt.go index 42137b7..774db0e 100644 --- a/stmt.go +++ b/stmt.go @@ -56,7 +56,10 @@ func (s *proxyStmt) NumInput() int { return num } - rec := currentSession.VerifyRecord(StmtNumInput) + rec, err := currentSession.VerifyRecord(StmtNumInput) + if err != nil { + panic(err) + } return rec.Args[0].(int) } @@ -96,8 +99,11 @@ func (s *proxyStmt) ExecContext( return &proxyResult{res: res}, nil } - rec := currentSession.VerifyRecord(StmtExec) - err, _ := rec.Args[0].(error) + rec, err := currentSession.VerifyRecord(StmtExec) + if err != nil { + return nil, err + } + err, _ = rec.Args[0].(error) if err != nil { return nil, err } @@ -140,8 +146,11 @@ func (s *proxyStmt) QueryContext( return &proxyRows{rows: rows}, nil } - rec := currentSession.VerifyRecord(StmtQuery) - err, _ := rec.Args[0].(error) + rec, err := currentSession.VerifyRecord(StmtQuery) + if err != nil { + return nil, err + } + err, _ = rec.Args[0].(error) if err != nil { return nil, err } diff --git a/tx.go b/tx.go index e2f1e68..f73cd81 100644 --- a/tx.go +++ b/tx.go @@ -32,8 +32,11 @@ func (t *proxyTx) Commit() error { return err } - record := currentSession.VerifyRecord(TxCommit) - err, _ := record.Args[0].(error) + record, err := currentSession.VerifyRecord(TxCommit) + if err != nil { + return err + } + err, _ = record.Args[0].(error) return err } @@ -45,7 +48,10 @@ func (t *proxyTx) Rollback() error { return err } - record := currentSession.VerifyRecord(TxRollback) - err, _ := record.Args[0].(error) + record, err := currentSession.VerifyRecord(TxRollback) + if err != nil { + return err + } + err, _ = record.Args[0].(error) return err }