Skip to content

Commit 4a27a75

Browse files
committed
Add WithTXOption expectation to ExpectBegin
1 parent 6bed17c commit 4a27a75

File tree

4 files changed

+85
-5
lines changed

4 files changed

+85
-5
lines changed

expectations.go

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package sqlmock
22

33
import (
4+
"database/sql"
45
"database/sql/driver"
56
"fmt"
67
"strings"
@@ -53,7 +54,8 @@ func (e *ExpectedClose) String() string {
5354
// returned by *Sqlmock.ExpectBegin.
5455
type ExpectedBegin struct {
5556
commonExpectation
56-
delay time.Duration
57+
delay time.Duration
58+
txOpts *driver.TxOptions
5759
}
5860

5961
// WillReturnError allows to set an error for *sql.DB.Begin action
@@ -65,6 +67,9 @@ func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
6567
// String returns string representation
6668
func (e *ExpectedBegin) String() string {
6769
msg := "ExpectedBegin => expecting database transaction Begin"
70+
if e.txOpts != nil {
71+
msg += fmt.Sprintf(", with tx options: %+v", e.txOpts)
72+
}
6873
if e.err != nil {
6974
msg += fmt.Sprintf(", which should return error: %s", e.err)
7075
}
@@ -78,6 +83,15 @@ func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin {
7883
return e
7984
}
8085

86+
// WithTxOptions allows to set transaction options for *sql.DB.Begin action
87+
func (e *ExpectedBegin) WithTxOptions(opts sql.TxOptions) *ExpectedBegin {
88+
e.txOpts = &driver.TxOptions{
89+
Isolation: driver.IsolationLevel(opts.Isolation),
90+
ReadOnly: opts.ReadOnly,
91+
}
92+
return e
93+
}
94+
8195
// ExpectedCommit is used to manage *sql.Tx.Commit expectation
8296
// returned by *Sqlmock.ExpectCommit.
8397
type ExpectedCommit struct {

sqlmock.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ func (c *sqlmock) ExpectationsWereMet() error {
213213

214214
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
215215
func (c *sqlmock) Begin() (driver.Tx, error) {
216-
ex, err := c.begin()
216+
ex, err := c.begin(driver.TxOptions{})
217217
if ex != nil {
218218
time.Sleep(ex.delay)
219219
}
@@ -224,7 +224,7 @@ func (c *sqlmock) Begin() (driver.Tx, error) {
224224
return c, nil
225225
}
226226

227-
func (c *sqlmock) begin() (*ExpectedBegin, error) {
227+
func (c *sqlmock) begin(opts driver.TxOptions) (*ExpectedBegin, error) {
228228
var expected *ExpectedBegin
229229
var ok bool
230230
var fulfilled int
@@ -252,9 +252,14 @@ func (c *sqlmock) begin() (*ExpectedBegin, error) {
252252
}
253253
return nil, fmt.Errorf(msg)
254254
}
255+
defer expected.Unlock()
256+
if expected.txOpts != nil &&
257+
expected.txOpts.Isolation != opts.Isolation &&
258+
expected.txOpts.ReadOnly != opts.ReadOnly {
259+
return nil, fmt.Errorf("expected transaction options do not match: %+v, got: %+v", expected.txOpts, opts)
260+
}
255261

256262
expected.triggered = true
257-
expected.Unlock()
258263

259264
return expected, expected.err
260265
}

sqlmock_go18.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
//go:build go1.8
12
// +build go1.8
23

34
package sqlmock
@@ -66,7 +67,7 @@ func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.N
6667

6768
// Implement the "ConnBeginTx" interface
6869
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
69-
ex, err := c.begin()
70+
ex, err := c.begin(opts)
7071
if ex != nil {
7172
select {
7273
case <-time.After(ex.delay):

sqlmock_go18_test.go

+60
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,66 @@ func TestContextBegin(t *testing.T) {
360360
}
361361
}
362362

363+
func TestContextBeginWithTxOptions(t *testing.T) {
364+
t.Parallel()
365+
db, mock, err := New()
366+
if err != nil {
367+
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
368+
}
369+
defer db.Close()
370+
371+
mock.ExpectBegin().WithTxOptions(sql.TxOptions{
372+
Isolation: sql.LevelReadCommitted,
373+
ReadOnly: true,
374+
})
375+
376+
ctx, cancel := context.WithCancel(context.Background())
377+
378+
go func() {
379+
time.Sleep(time.Millisecond * 10)
380+
cancel()
381+
}()
382+
383+
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: false})
384+
if err != nil {
385+
t.Errorf("error was not expected, but got: %v", err)
386+
}
387+
388+
if err := mock.ExpectationsWereMet(); err != nil {
389+
t.Errorf("there were unfulfilled expectations: %s", err)
390+
}
391+
}
392+
393+
func TestContextBeginWithTxOptionsMismatch(t *testing.T) {
394+
t.Parallel()
395+
db, mock, err := New()
396+
if err != nil {
397+
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
398+
}
399+
defer db.Close()
400+
401+
mock.ExpectBegin().WithTxOptions(sql.TxOptions{
402+
Isolation: sql.LevelReadCommitted,
403+
ReadOnly: true,
404+
})
405+
406+
ctx, cancel := context.WithCancel(context.Background())
407+
408+
go func() {
409+
time.Sleep(time.Millisecond * 10)
410+
cancel()
411+
}()
412+
413+
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelDefault, ReadOnly: false})
414+
if err == nil {
415+
t.Error("error was expected, but there was none")
416+
}
417+
418+
if err := mock.ExpectationsWereMet(); err == nil {
419+
t.Errorf("was expecting an error, as the tx options did not match, but there wasn't one")
420+
}
421+
}
422+
363423
func TestContextPrepareCancel(t *testing.T) {
364424
t.Parallel()
365425
db, mock, err := New()

0 commit comments

Comments
 (0)