Skip to content

Commit 1abf7d9

Browse files
authored
Merge pull request #2240 from bonnefoa/fix-watch-panic
Unwatch and close connection on a batch write error
2 parents b5efc90 + 228cfff commit 1abf7d9

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

pgconn/pgconn.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -1773,19 +1773,21 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
17731773

17741774
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
17751775
if batch.err != nil {
1776+
pgConn.contextWatcher.Unwatch()
1777+
multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
17761778
multiResult.closed = true
1777-
multiResult.err = batch.err
1778-
pgConn.unlock()
1779+
pgConn.asyncClose()
17791780
return multiResult
17801781
}
17811782

17821783
pgConn.enterPotentialWriteReadDeadlock()
17831784
defer pgConn.exitPotentialWriteReadDeadlock()
17841785
_, err := pgConn.conn.Write(batch.buf)
17851786
if err != nil {
1787+
pgConn.contextWatcher.Unwatch()
1788+
multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
17861789
multiResult.closed = true
1787-
multiResult.err = err
1788-
pgConn.unlock()
1790+
pgConn.asyncClose()
17891791
return multiResult
17901792
}
17911793

pgconn/pgconn_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,52 @@ func TestConnExecBatch(t *testing.T) {
14201420
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
14211421
}
14221422

1423+
type mockConnection struct {
1424+
net.Conn
1425+
writeLatency *time.Duration
1426+
}
1427+
1428+
func (m mockConnection) Write(b []byte) (n int, err error) {
1429+
time.Sleep(*m.writeLatency)
1430+
return m.Conn.Write(b)
1431+
}
1432+
1433+
func TestConnExecBatchWriteError(t *testing.T) {
1434+
t.Parallel()
1435+
1436+
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1437+
defer cancel()
1438+
1439+
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
1440+
require.NoError(t, err)
1441+
1442+
var mockConn mockConnection
1443+
writeLatency := 0 * time.Second
1444+
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
1445+
conn, err := net.Dial(network, address)
1446+
mockConn = mockConnection{conn, &writeLatency}
1447+
return mockConn, err
1448+
}
1449+
1450+
pgConn, err := pgconn.ConnectConfig(ctx, config)
1451+
require.NoError(t, err)
1452+
defer closeConn(t, pgConn)
1453+
1454+
batch := &pgconn.Batch{}
1455+
pgConn.Conn()
1456+
1457+
ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
1458+
defer cancel2()
1459+
1460+
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
1461+
writeLatency = 2 * time.Second
1462+
mrr := pgConn.ExecBatch(ctx2, batch)
1463+
err = mrr.Close()
1464+
require.Error(t, err)
1465+
assert.ErrorIs(t, err, context.DeadlineExceeded)
1466+
require.True(t, pgConn.IsClosed())
1467+
}
1468+
14231469
func TestConnExecBatchDeferredError(t *testing.T) {
14241470
t.Parallel()
14251471

0 commit comments

Comments
 (0)