Skip to content

Commit 2c1b1c3

Browse files
authoredJan 11, 2025··
Merge pull request #2200 from zenkovev/flush_request_in_pipeline
add flush request in pipeline
2 parents e877606 + c96a55f commit 2c1b1c3

File tree

2 files changed

+515
-40
lines changed

2 files changed

+515
-40
lines changed
 

‎pgconn/pgconn.go

+177-40
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package pgconn
22

33
import (
4+
"container/list"
45
"context"
56
"crypto/md5"
67
"crypto/tls"
@@ -1408,9 +1409,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
14081409

14091410
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
14101411
type MultiResultReader struct {
1411-
pgConn *PgConn
1412-
ctx context.Context
1413-
pipeline *Pipeline
1412+
pgConn *PgConn
1413+
ctx context.Context
14141414

14151415
rr *ResultReader
14161416

@@ -1443,12 +1443,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
14431443
switch msg := msg.(type) {
14441444
case *pgproto3.ReadyForQuery:
14451445
mrr.closed = true
1446-
if mrr.pipeline != nil {
1447-
mrr.pipeline.expectedReadyForQueryCount--
1448-
} else {
1449-
mrr.pgConn.contextWatcher.Unwatch()
1450-
mrr.pgConn.unlock()
1451-
}
1446+
mrr.pgConn.contextWatcher.Unwatch()
1447+
mrr.pgConn.unlock()
14521448
case *pgproto3.ErrorResponse:
14531449
mrr.err = ErrorResponseToPgError(msg)
14541450
}
@@ -1672,7 +1668,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
16721668
case *pgproto3.EmptyQueryResponse:
16731669
rr.concludeCommand(CommandTag{}, nil)
16741670
case *pgproto3.ErrorResponse:
1675-
rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
1671+
pgErr := ErrorResponseToPgError(msg)
1672+
if rr.pipeline != nil {
1673+
rr.pipeline.state.HandleError(pgErr)
1674+
}
1675+
rr.concludeCommand(CommandTag{}, pgErr)
16761676
}
16771677

16781678
return msg, nil
@@ -1999,9 +1999,7 @@ type Pipeline struct {
19991999
conn *PgConn
20002000
ctx context.Context
20012001

2002-
expectedReadyForQueryCount int
2003-
pendingSync bool
2004-
2002+
state pipelineState
20052003
err error
20062004
closed bool
20072005
}
@@ -2012,6 +2010,122 @@ type PipelineSync struct{}
20122010
// CloseComplete is returned by GetResults when a CloseComplete message is received.
20132011
type CloseComplete struct{}
20142012

2013+
type pipelineRequestType int
2014+
2015+
const (
2016+
pipelineNil pipelineRequestType = iota
2017+
pipelinePrepare
2018+
pipelineQueryParams
2019+
pipelineQueryPrepared
2020+
pipelineDeallocate
2021+
pipelineSyncRequest
2022+
pipelineFlushRequest
2023+
)
2024+
2025+
type pipelineRequestEvent struct {
2026+
RequestType pipelineRequestType
2027+
WasSentToServer bool
2028+
BeforeFlushOrSync bool
2029+
}
2030+
2031+
type pipelineState struct {
2032+
requestEventQueue list.List
2033+
lastRequestType pipelineRequestType
2034+
pgErr *PgError
2035+
expectedReadyForQueryCount int
2036+
}
2037+
2038+
func (s *pipelineState) Init() {
2039+
s.requestEventQueue.Init()
2040+
s.lastRequestType = pipelineNil
2041+
}
2042+
2043+
func (s *pipelineState) RegisterSendingToServer() {
2044+
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
2045+
val := elem.Value.(pipelineRequestEvent)
2046+
if val.WasSentToServer {
2047+
return
2048+
}
2049+
val.WasSentToServer = true
2050+
elem.Value = val
2051+
}
2052+
}
2053+
2054+
func (s *pipelineState) registerFlushingBufferOnServer() {
2055+
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
2056+
val := elem.Value.(pipelineRequestEvent)
2057+
if val.BeforeFlushOrSync {
2058+
return
2059+
}
2060+
val.BeforeFlushOrSync = true
2061+
elem.Value = val
2062+
}
2063+
}
2064+
2065+
func (s *pipelineState) PushBackRequestType(req pipelineRequestType) {
2066+
if req == pipelineNil {
2067+
return
2068+
}
2069+
2070+
if req != pipelineFlushRequest {
2071+
s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req})
2072+
}
2073+
if req == pipelineFlushRequest || req == pipelineSyncRequest {
2074+
s.registerFlushingBufferOnServer()
2075+
}
2076+
s.lastRequestType = req
2077+
2078+
if req == pipelineSyncRequest {
2079+
s.expectedReadyForQueryCount++
2080+
}
2081+
}
2082+
2083+
func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType {
2084+
for {
2085+
elem := s.requestEventQueue.Front()
2086+
if elem == nil {
2087+
return pipelineNil
2088+
}
2089+
val := elem.Value.(pipelineRequestEvent)
2090+
if !(val.WasSentToServer && val.BeforeFlushOrSync) {
2091+
return pipelineNil
2092+
}
2093+
2094+
s.requestEventQueue.Remove(elem)
2095+
if val.RequestType == pipelineSyncRequest {
2096+
s.pgErr = nil
2097+
}
2098+
if s.pgErr == nil {
2099+
return val.RequestType
2100+
}
2101+
}
2102+
}
2103+
2104+
func (s *pipelineState) HandleError(err *PgError) {
2105+
s.pgErr = err
2106+
}
2107+
2108+
func (s *pipelineState) HandleReadyForQuery() {
2109+
s.expectedReadyForQueryCount--
2110+
}
2111+
2112+
func (s *pipelineState) PendingSync() bool {
2113+
var notPendingSync bool
2114+
2115+
if elem := s.requestEventQueue.Back(); elem != nil {
2116+
val := elem.Value.(pipelineRequestEvent)
2117+
notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer
2118+
} else {
2119+
notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil)
2120+
}
2121+
2122+
return !notPendingSync
2123+
}
2124+
2125+
func (s *pipelineState) ExpectedReadyForQuery() int {
2126+
return s.expectedReadyForQueryCount
2127+
}
2128+
20152129
// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
20162130
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
20172131
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
@@ -2020,16 +2134,21 @@ type CloseComplete struct{}
20202134
// Prefer ExecBatch when only sending one group of queries at once.
20212135
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
20222136
if err := pgConn.lock(); err != nil {
2023-
return &Pipeline{
2137+
pipeline := &Pipeline{
20242138
closed: true,
20252139
err: err,
20262140
}
2141+
pipeline.state.Init()
2142+
2143+
return pipeline
20272144
}
20282145

20292146
pgConn.pipeline = Pipeline{
20302147
conn: pgConn,
20312148
ctx: ctx,
20322149
}
2150+
pgConn.pipeline.state.Init()
2151+
20332152
pipeline := &pgConn.pipeline
20342153

20352154
if ctx != context.Background() {
@@ -2052,45 +2171,76 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
20522171
if p.closed {
20532172
return
20542173
}
2055-
p.pendingSync = true
20562174

20572175
p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
20582176
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
2177+
p.state.PushBackRequestType(pipelinePrepare)
20592178
}
20602179

20612180
// SendDeallocate deallocates a prepared statement.
20622181
func (p *Pipeline) SendDeallocate(name string) {
20632182
if p.closed {
20642183
return
20652184
}
2066-
p.pendingSync = true
20672185

20682186
p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
2187+
p.state.PushBackRequestType(pipelineDeallocate)
20692188
}
20702189

20712190
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
20722191
func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
20732192
if p.closed {
20742193
return
20752194
}
2076-
p.pendingSync = true
20772195

20782196
p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
20792197
p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
20802198
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
20812199
p.conn.frontend.SendExecute(&pgproto3.Execute{})
2200+
p.state.PushBackRequestType(pipelineQueryParams)
20822201
}
20832202

20842203
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
20852204
func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
20862205
if p.closed {
20872206
return
20882207
}
2089-
p.pendingSync = true
20902208

20912209
p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
20922210
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
20932211
p.conn.frontend.SendExecute(&pgproto3.Execute{})
2212+
p.state.PushBackRequestType(pipelineQueryPrepared)
2213+
}
2214+
2215+
// SendFlushRequest sends a request for the server to flush its output buffer.
2216+
//
2217+
// The server flushes its output buffer automatically as a result of Sync being called,
2218+
// or on any request when not in pipeline mode; this function is useful to cause the server
2219+
// to flush its output buffer in pipeline mode without establishing a synchronization point.
2220+
// Note that the request is not itself flushed to the server automatically; use Flush if
2221+
// necessary. This copies the behavior of libpq PQsendFlushRequest.
2222+
func (p *Pipeline) SendFlushRequest() {
2223+
if p.closed {
2224+
return
2225+
}
2226+
2227+
p.conn.frontend.Send(&pgproto3.Flush{})
2228+
p.state.PushBackRequestType(pipelineFlushRequest)
2229+
}
2230+
2231+
// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message
2232+
// without flushing the send buffer. This serves as the delimiter of an implicit
2233+
// transaction and an error recovery point.
2234+
//
2235+
// Note that the request is not itself flushed to the server automatically; use Flush if
2236+
// necessary. This copies the behavior of libpq PQsendPipelineSync.
2237+
func (p *Pipeline) SendPipelineSync() {
2238+
if p.closed {
2239+
return
2240+
}
2241+
2242+
p.conn.frontend.SendSync(&pgproto3.Sync{})
2243+
p.state.PushBackRequestType(pipelineSyncRequest)
20942244
}
20952245

20962246
// Flush flushes the queued requests without establishing a synchronization point.
@@ -2115,28 +2265,14 @@ func (p *Pipeline) Flush() error {
21152265
return err
21162266
}
21172267

2268+
p.state.RegisterSendingToServer()
21182269
return nil
21192270
}
21202271

21212272
// Sync establishes a synchronization point and flushes the queued requests.
21222273
func (p *Pipeline) Sync() error {
2123-
if p.closed {
2124-
if p.err != nil {
2125-
return p.err
2126-
}
2127-
return errors.New("pipeline closed")
2128-
}
2129-
2130-
p.conn.frontend.SendSync(&pgproto3.Sync{})
2131-
err := p.Flush()
2132-
if err != nil {
2133-
return err
2134-
}
2135-
2136-
p.pendingSync = false
2137-
p.expectedReadyForQueryCount++
2138-
2139-
return nil
2274+
p.SendPipelineSync()
2275+
return p.Flush()
21402276
}
21412277

21422278
// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or
@@ -2150,7 +2286,7 @@ func (p *Pipeline) GetResults() (results any, err error) {
21502286
return nil, errors.New("pipeline closed")
21512287
}
21522288

2153-
if p.expectedReadyForQueryCount == 0 {
2289+
if p.state.ExtractFrontRequestType() == pipelineNil {
21542290
return nil, nil
21552291
}
21562292

@@ -2195,13 +2331,13 @@ func (p *Pipeline) getResults() (results any, err error) {
21952331
case *pgproto3.CloseComplete:
21962332
return &CloseComplete{}, nil
21972333
case *pgproto3.ReadyForQuery:
2198-
p.expectedReadyForQueryCount--
2334+
p.state.HandleReadyForQuery()
21992335
return &PipelineSync{}, nil
22002336
case *pgproto3.ErrorResponse:
22012337
pgErr := ErrorResponseToPgError(msg)
2338+
p.state.HandleError(pgErr)
22022339
return nil, pgErr
22032340
}
2204-
22052341
}
22062342
}
22072343

@@ -2231,6 +2367,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
22312367
// These should never happen here. But don't take chances that could lead to a deadlock.
22322368
case *pgproto3.ErrorResponse:
22332369
pgErr := ErrorResponseToPgError(msg)
2370+
p.state.HandleError(pgErr)
22342371
return nil, pgErr
22352372
case *pgproto3.CommandComplete:
22362373
p.conn.asyncClose()
@@ -2250,7 +2387,7 @@ func (p *Pipeline) Close() error {
22502387

22512388
p.closed = true
22522389

2253-
if p.pendingSync {
2390+
if p.state.PendingSync() {
22542391
p.conn.asyncClose()
22552392
p.err = errors.New("pipeline has unsynced requests")
22562393
p.conn.contextWatcher.Unwatch()
@@ -2259,7 +2396,7 @@ func (p *Pipeline) Close() error {
22592396
return p.err
22602397
}
22612398

2262-
for p.expectedReadyForQueryCount > 0 {
2399+
for p.state.ExpectedReadyForQuery() > 0 {
22632400
_, err := p.getResults()
22642401
if err != nil {
22652402
p.err = err

0 commit comments

Comments
 (0)
Please sign in to comment.