diff --git a/conn.go b/conn.go index 984c2c8..239024e 100644 --- a/conn.go +++ b/conn.go @@ -97,8 +97,11 @@ func (c *proxyConn) ExecContext( return &proxyResult{res: res}, nil } - rec := currentSession.VerifyRecordWithStringArg(ConnExec, query) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecordWithStringArg(ConnExec, query) + if err != nil { + return nil, err + } + err, _ = rec.Args[1].(error) if err != nil { return nil, err } @@ -130,8 +133,11 @@ func (c *proxyConn) PrepareContext(ctx context.Context, query string) (driver.St return &proxyStmt{stmt: stmt}, nil } - rec := currentSession.VerifyRecordWithStringArg(ConnPrepare, query) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecordWithStringArg(ConnPrepare, query) + if err != nil { + return nil, err + } + err, _ = rec.Args[1].(error) if err != nil { return nil, err } @@ -169,8 +175,11 @@ func (c *proxyConn) QueryContext( return &proxyRows{rows: rows}, nil } - rec := currentSession.VerifyRecordWithStringArg(ConnQuery, query) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecordWithStringArg(ConnQuery, query) + if err != nil { + return nil, err + } + err, _ = rec.Args[1].(error) if err != nil { return nil, err } @@ -232,8 +241,11 @@ func (c *proxyConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver. return &proxyTx{tx: tx}, nil } - rec := currentSession.VerifyRecord(ConnBegin) - err, _ := rec.Args[0].(error) + rec, err := currentSession.VerifyRecord(ConnBegin) + if err != nil { + return nil, err + } + err, _ = rec.Args[0].(error) if err != nil { return nil, err } diff --git a/copyist.go b/copyist.go index 00098ed..73b0aa3 100644 --- a/copyist.go +++ b/copyist.go @@ -94,7 +94,7 @@ var registered map[string]*proxyDriver // The Register method takes the name of the SQL driver to be wrapped (e.g. // "postgres"). Below is an example of how copyist.Register should be invoked. // -// copyist.Register("postgres") +// copyist.Register("postgres") // // Note that Register can only be called once for a given driver; subsequent // attempts will fail with an error. In addition, the same copyist driver must @@ -139,22 +139,22 @@ func SetSessionInit(callback SessionInitCallback) { // alongside the calling test file. If playing back, then the recording will // be fetched from that recording file. Here is a typical calling pattern: // -// func init() { -// copyist.Register("postgres") -// } +// func init() { +// copyist.Register("postgres") +// } // -// func TestMyStuff(t *testing.T) { -// defer copyist.Open(t).Close() -// ... -// } +// func TestMyStuff(t *testing.T) { +// defer copyist.Open(t).Close() +// ... +// } // // The call to Open will initiate a new recording session. The deferred call to // Close will complete the recording session and write the recording to a file // in the testdata/ directory, like: // -// mystuff_test.go -// testdata/ -// mystuff_test.copyist +// mystuff_test.go +// testdata/ +// mystuff_test.copyist // // Each test or sub-test that needs to be executed independently needs to record // its own session. @@ -215,6 +215,10 @@ func OpenSource(t testingT, source Source, recordingName string) io.Closer { panic(r) } + if currentSession.verificationErr != nil { + t.Fatalf("%+v\n", currentSession.verificationErr.error) + } + currentSession.Close() currentSession = nil return nil diff --git a/copyist_test.go b/copyist_test.go index 38bb0a4..0e1c020 100644 --- a/copyist_test.go +++ b/copyist_test.go @@ -92,7 +92,7 @@ func (t *mockTestingT) Fatalf(format string, args ...interface{}) { fmt.Fprintf(&t.buf, format, args...) } -func TestSessionPanicsAreCaught(t *testing.T) { +func TestSessionFailuresAreFatalfd(t *testing.T) { // Enter playback mode. *recordFlag = false visitedRecording = true @@ -102,7 +102,7 @@ func TestSessionPanicsAreCaught(t *testing.T) { m := &mockTestingT{T: t} defer func() { - require.Equal(t, "no recording exists with this name: TestSessionPanicsAreCaught\n", + require.Equal(t, "no recording exists with this name: TestSessionFailuresAreFatalfd\n", m.buf.String()) }() @@ -110,8 +110,8 @@ func TestSessionPanicsAreCaught(t *testing.T) { db, err := sql.Open("copyist_postgres2", "") require.NoError(t, err) - // NB: This will panic, but the panic will be caught by the copyist closer and - // converted into a call to testing.T.Fatalf. + // NB: This will return an error, but the the copyist closer will track the + // first error and convert it into a call to testing.T.Fatalf. db.Query("SELECT 1") } diff --git a/driver.go b/driver.go index 013c8b7..6e647b8 100644 --- a/driver.go +++ b/driver.go @@ -125,8 +125,11 @@ func (d *proxyDriver) Open(name string) (driver.Conn, error) { return &proxyConn{driver: d, conn: conn, name: name, session: currentSession}, nil } - rec := currentSession.VerifyRecord(DriverOpen) - err, _ := rec.Args[0].(error) + rec, err := currentSession.VerifyRecord(DriverOpen) + if err != nil { + return nil, err + } + err, _ = rec.Args[0].(error) if err != nil { return nil, err } diff --git a/drivertest/commontest/common_test.go b/drivertest/commontest/common_test.go index d9fce46..f29e889 100644 --- a/drivertest/commontest/common_test.go +++ b/drivertest/commontest/common_test.go @@ -17,6 +17,7 @@ package commontest_test import ( "bytes" "database/sql" + "fmt" "io" "testing" @@ -94,6 +95,84 @@ func TestOpenReadWriteCloser(t *testing.T) { rows.Next() } +func TestRollbackWithRecover(t *testing.T) { + // This bug is only present in playback mode, short circuit if we're + // recording. + if copyist.IsRecording() { + return + } + + // This is a regression test for a deadlock when copyist would panic upon + // recording failures. We mount an intentionally out of date source that + // will fail on any action after opening our transaction. Our transaction + // helper will attempt a rollback in the case of an error or a panic, which + // would catch copyist's old behavior of panicking upon out of date + // recordings. + // We assert that we hit an out of date error and that rollback is called + // and returns. + source := CopyistSource(bytes.NewBuffer([]byte(` +1=DriverOpen 1:nil +2=ConnBegin 1:nil + +"TestRollbackWithRecover"=1,2`))) + + defer leaktest.Check(t)() + + var mt mockT + + closer := copyist.OpenSource(&mt, source, t.Name()) + + // Open database. + db, err := sql.Open("copyist_postgres", dataSourceName) + require.NoError(t, err) + defer db.Close() + + fnErr, txErr := execTransaction(db, func(tx *sql.Tx) error { + _, err := tx.Query("SELECT 1") + return err + }) + + require.EqualError(t, fnErr, "too many calls to ConnQuery\n\nDo you need to regenerate the recording with the -record flag?") + require.EqualError(t, txErr, "too many calls to TxRollback\n\nDo you need to regenerate the recording with the -record flag?") + + require.NoError(t, closer.Close()) // closer never errors. + + // Verify that the call to .Close invokes t.Fatalf with the first session + // error that we encountered. + require.Contains(t, mt.failure, "too many calls to ConnQuery") + // Verify that t.Fatalf includes the stacktrace leading to the call that + // triggered the first error. In this case, we look for the error coming + // from the first closure defined within this test function. + require.Contains(t, mt.failure, fmt.Sprintf("commontest_test.%s.func1", t.Name())) +} + +// execTransaction is a transaction helper function that attempts a rollback in +// the case of panics of errors. It returns both the closure error and the +// error of either commiting or rolling back. +// It is intended to mimic the behavior of +// https://github.com/cockroachdb/cockroach-go/blob/21a237074d6c3c35b68ec43e8d0c9e9ed714d21a/crdb/common.go#L38 +func execTransaction(db *sql.DB, fn func(*sql.Tx) error) (fnErr error, txErr error) { + tx, err := db.Begin() + if err != nil { + return nil, err + } + + defer func() { + if r := recover(); r != nil { + txErr = tx.Rollback() + panic(r) + } + + if fnErr == nil { + txErr = tx.Commit() + } else { + txErr = tx.Rollback() + } + }() + + return fn(tx), nil +} + func TestIsOpen(t *testing.T) { require.False(t, copyist.IsOpen()) @@ -164,3 +243,12 @@ func (s copyistSource) WriteAll([]byte) error { func CopyistSource(r io.Reader) copyist.Source { return copyistSource{r} } + +type mockT struct { + failure string +} + +func (mockT) Name() string { return "" } +func (t *mockT) Fatalf(format string, args ...interface{}) { + t.failure = fmt.Sprintf(format, args...) +} diff --git a/drivertest/pqtestold/go.sum b/drivertest/pqtestold/go.sum index b210669..3a090f9 100644 --- a/drivertest/pqtestold/go.sum +++ b/drivertest/pqtestold/go.sum @@ -81,6 +81,7 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/go.mod b/go.mod index ec39061..52bafe7 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,6 @@ require ( github.com/jackc/pgx/v4 v4.13.0 github.com/jmoiron/sqlx v1.3.4 github.com/lib/pq v1.10.2 + github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.7.0 ) diff --git a/result.go b/result.go index dbbfc58..557f6bc 100644 --- a/result.go +++ b/result.go @@ -34,8 +34,11 @@ func (r *proxyResult) LastInsertId() (int64, error) { return id, err } - rec := currentSession.VerifyRecord(ResultLastInsertId) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecord(ResultLastInsertId) + if err != nil { + return 0, err + } + err, _ = rec.Args[1].(error) if err != nil { return 0, err } @@ -51,8 +54,11 @@ func (r *proxyResult) RowsAffected() (int64, error) { return affected, err } - rec := currentSession.VerifyRecord(ResultRowsAffected) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecord(ResultRowsAffected) + if err != nil { + return 0, err + } + err, _ = rec.Args[1].(error) if err != nil { return 0, err } diff --git a/rows.go b/rows.go index d56aa84..d9bf3fe 100644 --- a/rows.go +++ b/rows.go @@ -35,7 +35,10 @@ func (r *proxyRows) Columns() []string { return cols } - rec := currentSession.VerifyRecord(RowsColumns) + rec, err := currentSession.VerifyRecord(RowsColumns) + if err != nil { + panic(err) + } return rec.Args[0].([]string) } @@ -70,8 +73,11 @@ func (r *proxyRows) Next(dest []driver.Value) error { return err } - rec := currentSession.VerifyRecord(RowsNext) - err, _ := rec.Args[1].(error) + rec, err := currentSession.VerifyRecord(RowsNext) + if err != nil { + return err + } + err, _ = rec.Args[1].(error) if err != nil { return err } diff --git a/session.go b/session.go index 58c3503..08824b0 100644 --- a/session.go +++ b/session.go @@ -17,6 +17,8 @@ package copyist import ( "fmt" "os" + + "github.com/pkg/errors" ) // session is state used during copyist recording and playback to track progress @@ -39,6 +41,10 @@ type session struct { // isInit is set to true once this session has been initialized. isInit bool + + // verificationErr is the first sessionError encountered when replaying + // this session for better error reporting later on. + verificationErr *sessionError } // currentSession is a global instance of session that tracks state for the @@ -107,33 +113,36 @@ func (s *session) AddRecord(rec *record) { // VerifyRecordWithStringArg returns one of the records in this session's // recording, failing with a nice error if no such record exists, or if its // first argument does not match the given string. -func (s *session) VerifyRecordWithStringArg(recordTyp recordType, arg string) *record { - rec := s.VerifyRecord(recordTyp) +func (s *session) VerifyRecordWithStringArg(recordTyp recordType, arg string) (*record, error) { + rec, err := s.VerifyRecord(recordTyp) + if err != nil { + return nil, err + } if rec.Args[0].(string) != arg { - panicf( + return nil, s.sessionErr( "mismatched argument to %s, expected %s, got %s\n\n"+ "Do you need to regenerate the recording with the -record flag?", recordTyp.String(), rec.Args[0].(string), arg) } - return rec + return rec, nil } // VerifyRecord returns one of the records in this session's recording, failing // with a nice error if no such record exists. -func (s *session) VerifyRecord(recordTyp recordType) *record { +func (s *session) VerifyRecord(recordTyp recordType) (*record, error) { if s.index >= len(s.recording) { - panicf( + return nil, s.sessionErr( "too many calls to %s\n\n"+ "Do you need to regenerate the recording with the -record flag?", recordTyp.String()) } rec := s.recording[s.index] if rec.Typ != recordTyp { - panicf( + return nil, s.sessionErr( "unexpected call to %s\n\n"+ "Do you need to regenerate the recording with the -record flag?", recordTyp.String()) } s.index++ - return rec + return rec, nil } // Close ends this session, writing any recording file and clearing state. @@ -157,7 +166,15 @@ func (s *session) Close() { clearPooledConnections() } -func panicf(format string, args ...interface{}) { +func (s *session) sessionErr(format string, args ...interface{}) error { + err := &sessionError{errors.Errorf(format, args...)} + if s.verificationErr == nil { + s.verificationErr = err + } + return err +} + +func panicf(format string, args ...interface{}) error { panic(&sessionError{fmt.Errorf(format, args...)}) } diff --git a/stmt.go b/stmt.go index 42137b7..774db0e 100644 --- a/stmt.go +++ b/stmt.go @@ -56,7 +56,10 @@ func (s *proxyStmt) NumInput() int { return num } - rec := currentSession.VerifyRecord(StmtNumInput) + rec, err := currentSession.VerifyRecord(StmtNumInput) + if err != nil { + panic(err) + } return rec.Args[0].(int) } @@ -96,8 +99,11 @@ func (s *proxyStmt) ExecContext( return &proxyResult{res: res}, nil } - rec := currentSession.VerifyRecord(StmtExec) - err, _ := rec.Args[0].(error) + rec, err := currentSession.VerifyRecord(StmtExec) + if err != nil { + return nil, err + } + err, _ = rec.Args[0].(error) if err != nil { return nil, err } @@ -140,8 +146,11 @@ func (s *proxyStmt) QueryContext( return &proxyRows{rows: rows}, nil } - rec := currentSession.VerifyRecord(StmtQuery) - err, _ := rec.Args[0].(error) + rec, err := currentSession.VerifyRecord(StmtQuery) + if err != nil { + return nil, err + } + err, _ = rec.Args[0].(error) if err != nil { return nil, err } diff --git a/tx.go b/tx.go index e2f1e68..f73cd81 100644 --- a/tx.go +++ b/tx.go @@ -32,8 +32,11 @@ func (t *proxyTx) Commit() error { return err } - record := currentSession.VerifyRecord(TxCommit) - err, _ := record.Args[0].(error) + record, err := currentSession.VerifyRecord(TxCommit) + if err != nil { + return err + } + err, _ = record.Args[0].(error) return err } @@ -45,7 +48,10 @@ func (t *proxyTx) Rollback() error { return err } - record := currentSession.VerifyRecord(TxRollback) - err, _ := record.Args[0].(error) + record, err := currentSession.VerifyRecord(TxRollback) + if err != nil { + return err + } + err, _ = record.Args[0].(error) return err }