From 3c493b214e73a3b875cca78e0779738c648829cb Mon Sep 17 00:00:00 2001 From: djshow832 Date: Mon, 15 Jan 2024 10:44:48 +0800 Subject: [PATCH] backend, logger: fix data race in UT (#443) --- pkg/manager/logger/manager_test.go | 59 +++++++------- pkg/proxy/backend/authenticator.go | 7 +- pkg/proxy/backend/backend_conn_mgr.go | 78 +++++++++++------- pkg/proxy/backend/backend_conn_mgr_test.go | 93 +++++++++++++++++++--- pkg/proxy/backend/error.go | 4 + pkg/proxy/backend/handshake_handler.go | 4 + pkg/proxy/backend/mock_proxy_test.go | 8 +- pkg/proxy/proxy.go | 7 ++ 8 files changed, 188 insertions(+), 72 deletions(-) diff --git a/pkg/manager/logger/manager_test.go b/pkg/manager/logger/manager_test.go index 52e7fc68..a86c9eea 100644 --- a/pkg/manager/logger/manager_test.go +++ b/pkg/manager/logger/manager_test.go @@ -184,49 +184,54 @@ func readLogFiles(t *testing.T, dir string) []os.FileInfo { func TestLogConcurrently(t *testing.T) { dir := t.TempDir() fileName := filepath.Join(dir, "proxy.log") - cfg := &config.Config{ - Log: config.Log{ - Encoder: "tidb", - LogOnline: config.LogOnline{ - Level: "info", - LogFile: config.LogFile{ - Filename: fileName, - MaxSize: 1, - MaxDays: 2, - MaxBackups: 3, + + newCfg := func() *config.Config { + return &config.Config{ + Log: config.Log{ + Encoder: "tidb", + LogOnline: config.LogOnline{ + Level: "info", + LogFile: config.LogFile{ + Filename: fileName, + MaxSize: 1, + MaxDays: 2, + MaxBackups: 3, + }, }, }, - }, + } } - lg, ch := setupLogManager(t, cfg) + lg, ch := setupLogManager(t, newCfg()) var wg waitgroup.WaitGroup ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) for i := 0; i < 5; i++ { wg.Run(func() { for ctx.Err() == nil { - lg = lg.Named("test_name") - lg.Info("test_info") - lg.Warn("test_warn") - lg.Error("test_error") - lg = lg.With(zap.String("with", "test_with")) - lg.Info("test_info") - lg.Warn("test_warn") - lg.Error("test_error") + namedLg := lg.Named("test_name") + namedLg.Info("test_info") + namedLg.Warn("test_warn") + namedLg.Error("test_error") + withLg := namedLg.With(zap.String("with", "test_with")) + withLg.Info("test_info") + withLg.Warn("test_warn") + withLg.Error("test_error") } }) } wg.Run(func() { - newCfg := cfg.Clone() for ctx.Err() == nil { - newCfg.Log.LogFile.MaxDays = int(rand.Int31n(10)) - ch <- newCfg + cfg := newCfg() + cfg.Log.LogFile.MaxDays = int(rand.Int31n(10)) + ch <- cfg time.Sleep(10 * time.Millisecond) - newCfg.Log.LogFile.MaxBackups = int(rand.Int31n(10)) - ch <- newCfg + cfg = newCfg() + cfg.Log.LogFile.MaxBackups = int(rand.Int31n(10)) + ch <- cfg time.Sleep(10 * time.Millisecond) - newCfg.Log.LogFile.MaxSize = int(rand.Int31n(10)) - ch <- newCfg + cfg = newCfg() + cfg.Log.LogFile.MaxSize = int(rand.Int31n(10)) + ch <- cfg time.Sleep(10 * time.Millisecond) } }) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index c753cd87..cc722f36 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -5,6 +5,7 @@ package backend import ( "bytes" + "context" "crypto/tls" "encoding/binary" "fmt" @@ -82,9 +83,9 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili return nil } -type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) +type backendIOGetter func(ctx context.Context, cctx ConnContext, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) -func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnContext, clientIO 
*pnet.PacketIO, handshakeHandler HandshakeHandler, +func (auth *Authenticator) handshakeFirstTime(ctx context.Context, logger *zap.Logger, cctx ConnContext, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler, getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { clientIO.ResetSequence() @@ -159,7 +160,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte RECONNECT: // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. - backendIO, err := getBackendIO(cctx, auth, clientResp) + backendIO, err := getBackendIO(ctx, cctx, clientResp) if err != nil { return err } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 8ea4d256..2103685a 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -100,7 +100,7 @@ func (cfg *BCConfig) check() { // - If it retries after each command: the latency will be unacceptable afterwards if it always fails. // - If it stops receiving signals: the previous new backend may be abnormal but the next new backend may be good. type BackendConnManager struct { - // processLock makes redirecting and command processing exclusive. + // processLock makes all processes exclusive. processLock sync.Mutex wg waitgroup.WaitGroup // signalReceived is used to notify the signal processing goroutine. @@ -110,10 +110,11 @@ type BackendConnManager struct { eventReceiver unsafe.Pointer config *BCConfig logger *zap.Logger - // It will be set to nil after migration. + // Redirect() sets it without lock. It will be set to nil after migration. redirectInfo atomic.Pointer[router.BackendInst] // redirectResCh is used to notify the event receiver asynchronously. - redirectResCh chan *redirectResult + redirectResCh chan *redirectResult + // GracefulClose() sets it without lock. closeStatus atomic.Int32 checkBackendTicker *time.Ticker // cancelFunc is used to cancel the signal processing goroutine. 
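Note on the manager_test.go hunk above: the data race it fixes is twofold — the worker goroutines no longer reassign the shared lg variable (each iteration derives its own Named/With logger), and the config sender now builds a fresh config per send instead of mutating an object it has already handed to the log manager. A minimal standalone sketch of the second pattern; the types and channel here are illustrative, not the patch's own:

// cfg_race_sketch.go — allocate a fresh object per send so the receiver never
// reads memory the sender is still mutating.
package main

import (
	"fmt"
	"math/rand"
	"time"
)

type logFile struct{ MaxDays, MaxBackups, MaxSize int }

func main() {
	ch := make(chan *logFile, 4)
	done := make(chan struct{})

	// Receiver: reads fields of whatever arrives, like the log manager does.
	go func() {
		defer close(done)
		for cfg := range ch {
			fmt.Println(cfg.MaxDays, cfg.MaxBackups, cfg.MaxSize)
		}
	}()

	for i := 0; i < 3; i++ {
		// Fresh allocation per send; reusing and mutating one shared object
		// after sending it (the old test's pattern) races with the receiver.
		cfg := &logFile{MaxDays: int(rand.Int31n(10))}
		ch <- cfg
		time.Sleep(10 * time.Millisecond)
	}
	close(ch)
	<-done
}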
@@ -163,9 +164,13 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe defer mgr.processLock.Unlock() mgr.backendTLS = backendTLSConfig - mgr.clientIO = clientIO - err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) + + if mgr.closeStatus.Load() >= statusNotifyClose { + mgr.quitSource = SrcProxyQuit + return errors.New("graceful shutdown before connecting") + } + err := mgr.authenticator.handshakeFirstTime(ctx, mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) if err != nil { src := Error2Source(err) mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err, src) @@ -188,7 +193,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe return nil } -func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { +func (mgr *BackendConnManager) getBackendIO(ctx context.Context, cctx ConnContext, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { r, err := mgr.handshakeHandler.GetRouter(cctx, resp) if err != nil { return nil, errors.Wrap(ErrProxyErr, err) @@ -196,7 +201,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato // Reasons to wait: // - The TiDB instances may not be initialized yet // - One TiDB may be just shut down and another is just started but not ready yet - bctx, cancel := context.WithTimeout(context.Background(), mgr.config.ConnectTimeout) + bctx, cancel := context.WithTimeout(ctx, mgr.config.ConnectTimeout) selector := r.GetBackendSelector() startTime := time.Now() var addr string @@ -259,9 +264,11 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato // ExecuteCmd forwards messages between the client and the backend. // If it finds that the session is ready for redirection, it migrates the session. func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) (err error) { + mgr.processLock.Lock() defer func() { mgr.setQuitSourceByErr(err) mgr.handshakeHandler.OnTraffic(mgr) + mgr.processLock.Unlock() }() if len(request) < 1 { err = mysql.ErrMalformPacket @@ -269,13 +276,13 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) ( } cmd := pnet.Command(request[0]) startTime := time.Now() - mgr.processLock.Lock() - defer mgr.processLock.Unlock() - switch mgr.closeStatus.Load() { - case statusClosing, statusClosed: + // Once the request is accepted, it's treated in the transaction, so we don't check graceful shutdown here. + if mgr.closeStatus.Load() >= statusClosing { return } + // The query may last over CheckBackendInterval. In this case we don't need to check the backend after the query. + mgr.checkBackendTicker.Stop() defer mgr.resetCheckBackendTicker() waitingRedirect := mgr.redirectInfo.Load() != nil var holdRequest bool @@ -400,8 +407,7 @@ func (mgr *BackendConnManager) processSignals(ctx context.Context) { // tryRedirect tries to migrate the session if the session is redirect-able. // NOTE: processLock should be held before calling this function. 
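The getBackendIO change above derives the connect timeout from the caller's context instead of context.Background(), so cancelling the connection's context (for example when the manager is closed) aborts a pending backend dial rather than waiting out the full ConnectTimeout. A standalone sketch of that pattern, assuming a plain TCP dial; the address is illustrative and chosen only because it is unlikely to answer:

// ctx_dial_sketch.go — whichever fires first, parent cancellation or the
// per-dial timeout, stops the dial.
package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

func dialBackend(ctx context.Context, addr string, timeout time.Duration) (net.Conn, error) {
	// Timeout layered on the caller's context, as the patch does for getBackendIO.
	bctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	var d net.Dialer
	return d.DialContext(bctx, "tcp", addr)
}

func main() {
	parent, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(50 * time.Millisecond)
		cancel() // simulate Close arriving while the connection is still being set up
	}()
	_, err := dialBackend(parent, "10.255.255.1:4000", 5*time.Second)
	fmt.Println("dial aborted:", err) // returns promptly with a context error
}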
func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { - switch mgr.closeStatus.Load() { - case statusNotifyClose, statusClosing, statusClosed: + if mgr.closeStatus.Load() >= statusNotifyClose || ctx.Err() != nil { return } backendInst := mgr.redirectInfo.Load() @@ -442,6 +448,10 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { } return } + if ctx.Err() != nil { + rs.err = ctx.Err() + return + } if rs.err = mgr.updateAuthInfoFromSessionStates(hack.Slice(sessionStates)); rs.err != nil { return } @@ -492,8 +502,7 @@ func (mgr *BackendConnManager) updateAuthInfoFromSessionStates(sessionStates []b // Note that the caller requires the function to be non-blocking. func (mgr *BackendConnManager) Redirect(backendInst router.BackendInst) bool { // NOTE: BackendConnManager may be closing concurrently because of no lock. - switch mgr.closeStatus.Load() { - case statusNotifyClose, statusClosing, statusClosed: + if mgr.closeStatus.Load() >= statusNotifyClose { return false } mgr.redirectInfo.Store(&backendInst) @@ -523,12 +532,13 @@ func (mgr *BackendConnManager) notifyRedirectResult(ctx context.Context, rs *red // GracefulClose waits for the end of the transaction and closes the session. func (mgr *BackendConnManager) GracefulClose() { - mgr.closeStatus.Store(statusNotifyClose) - mgr.signalReceived <- signalTypeGracefulClose + if mgr.closeStatus.CompareAndSwap(statusActive, statusNotifyClose) { + mgr.signalReceived <- signalTypeGracefulClose + } } func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context) { - if mgr.closeStatus.Load() != statusNotifyClose { + if mgr.closeStatus.Load() != statusNotifyClose || ctx.Err() != nil { return } if !mgr.cmdProcessor.finishedTxn() { @@ -539,17 +549,16 @@ func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context) { if err := mgr.clientIO.GracefulClose(); err != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("client_addr", mgr.clientIO.RemoteAddr()), zap.Error(err)) } - mgr.closeStatus.Store(statusClosing) + mgr.closeStatus.CompareAndSwap(statusNotifyClose, statusClosing) } func (mgr *BackendConnManager) checkBackendActive() { - switch mgr.closeStatus.Load() { - case statusClosing, statusClosed: - return - } - mgr.processLock.Lock() defer mgr.processLock.Unlock() + + if mgr.closeStatus.Load() >= statusNotifyClose { + return + } backendIO := mgr.backendIO.Load() if !backendIO.IsPeerActive() { mgr.logger.Info("backend connection is closed, close client connection", @@ -558,7 +567,7 @@ func (mgr *BackendConnManager) checkBackendActive() { if err := mgr.clientIO.GracefulClose(); err != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("client_addr", mgr.clientIO.RemoteAddr()), zap.Error(err)) } - mgr.closeStatus.Store(statusClosing) + mgr.closeStatus.CompareAndSwap(statusActive, statusClosing) } } @@ -618,6 +627,17 @@ func (mgr *BackendConnManager) Value(key any) any { // Close releases all resources. func (mgr *BackendConnManager) Close() error { + // BackendConnMgr may close even before connecting, so protect the members with a lock. + mgr.processLock.Lock() + defer func() { + mgr.processLock.Unlock() + // Wait out of the lock to avoid deadlock. 
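GracefulClose, tryGracefulClose and checkBackendActive above now advance closeStatus with CompareAndSwap instead of a plain Store, so a concurrent force Close that has already moved the status past statusActive cannot be rolled backwards by a later graceful-close signal. A standalone sketch of the idea; the constant names follow the patch, but their numeric ordering is an assumption:

// close_status_sketch.go — CAS only moves the status forward.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

const (
	statusActive int32 = iota
	statusNotifyClose
	statusClosing
	statusClosed
)

func main() {
	var status atomic.Int32
	var wg sync.WaitGroup
	wg.Add(2)
	go func() { // force-close path: unconditionally enters closing
		defer wg.Done()
		status.Store(statusClosing)
	}()
	go func() { // graceful-close path: only takes effect while still active
		defer wg.Done()
		if status.CompareAndSwap(statusActive, statusNotifyClose) {
			fmt.Println("graceful close accepted")
		}
	}()
	wg.Wait()
	fmt.Println("final status:", status.Load()) // never regresses below statusClosing
}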
+ mgr.wg.Wait() + }() + if mgr.closeStatus.Load() >= statusClosed { + return nil + } + mgr.closeStatus.Store(statusClosing) if mgr.checkBackendTicker != nil { mgr.checkBackendTicker.Stop() @@ -626,18 +646,16 @@ func (mgr *BackendConnManager) Close() error { mgr.cancelFunc() mgr.cancelFunc = nil } - mgr.wg.Wait() + // OnConnClose may read ServerAddr(), so call it before closing backendIO. handErr := mgr.handshakeHandler.OnConnClose(mgr, mgr.quitSource) var connErr error var addr string - mgr.processLock.Lock() if backendIO := mgr.backendIO.Swap(nil); backendIO != nil { addr = backendIO.RemoteAddr().String() connErr = backendIO.Close() } - mgr.processLock.Unlock() eventReceiver := mgr.getEventReceiver() if eventReceiver != nil { @@ -690,16 +708,20 @@ func (mgr *BackendConnManager) setQuitSourceByErr(err error) { mgr.quitSource = Error2Source(err) } +// UpdateLogger add fields to the logger. +// Note: it should be called within the lock. func (mgr *BackendConnManager) UpdateLogger(fields ...zap.Field) { mgr.logger = mgr.logger.With(fields...) } // ConnInfo returns detailed info of the connection, which should not be logged too many times. func (mgr *BackendConnManager) ConnInfo() []zap.Field { + mgr.processLock.Lock() var fields []zap.Field if mgr.authenticator != nil { fields = mgr.authenticator.ConnInfo() } + mgr.processLock.Unlock() fields = append(fields, zap.String("backend_addr", mgr.ServerAddr())) return fields } diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index e5fcbb5f..483c6037 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -742,18 +742,11 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) { return nil }, }, - // 1st handshake - { - client: ts.mc.authenticate, - proxy: ts.firstHandshake4Proxy, - backend: ts.handshake4Backend, - }, - // it will then automatically close - { - proxy: ts.checkConnClosed4Proxy, - }, + // connect fails { proxy: func(clientIO, backendIO *pnet.PacketIO) error { + err := ts.mp.Connect(context.Background(), clientIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig) + require.Error(ts.t, err) require.Equal(t, SrcProxyQuit, ts.mp.QuitSource()) return nil }, @@ -893,7 +886,7 @@ func TestGetBackendIO(t *testing.T) { require.NoError(t, cn.Close()) } }) - io, err := mgr.getBackendIO(mgr, mgr.authenticator, nil) + io, err := mgr.getBackendIO(context.Background(), mgr, nil) if err == nil { require.NoError(t, io.Close()) } @@ -1148,3 +1141,81 @@ func TestBackendStatusChange(t *testing.T) { ts.runTests(runners) } + +func TestCloseWhileConnect(t *testing.T) { + ts := newBackendMgrTester(t) + runners := []runner{ + // 1st handshake while force close + { + client: ts.mc.authenticate, + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + go func() { + require.NoError(ts.t, ts.mp.BackendConnManager.Close()) + }() + err := ts.mp.Connect(context.Background(), clientIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig) + if err == nil { + mer := newMockEventReceiver() + ts.mp.SetEventReceiver(mer) + } + return err + }, + backend: ts.handshake4Backend, + }, + } + + ts.runTests(runners) +} + +func TestCloseWhileExecute(t *testing.T) { + ts := newBackendMgrTester(t) + runners := []runner{ + // 1st handshake + { + client: ts.mc.authenticate, + proxy: ts.firstHandshake4Proxy, + backend: ts.handshake4Backend, + }, + // execute cmd while force close + { + client: ts.mc.request, + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + 
clientIO.ResetSequence() + request, err := clientIO.ReadPacket() + if err != nil { + return err + } + go func() { + require.NoError(ts.t, ts.mp.BackendConnManager.Close()) + }() + return ts.mp.ExecuteCmd(context.Background(), request) + }, + backend: ts.startTxn4Backend, + }, + } + + ts.runTests(runners) +} + +func TestCloseWhileGracefulClose(t *testing.T) { + ts := newBackendMgrTester(t) + runners := []runner{ + // 1st handshake + { + client: ts.mc.authenticate, + proxy: ts.firstHandshake4Proxy, + backend: ts.handshake4Backend, + }, + // graceful close while force close + { + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + go func() { + require.NoError(ts.t, ts.mp.BackendConnManager.Close()) + }() + ts.mp.GracefulClose() + return nil + }, + }, + } + + ts.runTests(runners) +} diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index 1342852a..f0f219e1 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -4,6 +4,8 @@ package backend import ( + "context" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tiproxy/lib/util/errors" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" @@ -121,6 +123,8 @@ func Error2Source(err error) ErrorSource { case pnet.IsMySQLError(err): // ErrClientAuthFail and ErrBackendHandshake may also contain MySQL error. return SrcClientSQLErr + case errors.Is(err, context.Canceled): + return SrcProxyQuit default: // All other untracked errors are proxy errors. return SrcProxyErr diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 98e2cee5..4a144db7 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -26,6 +26,8 @@ const ( var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) var _ HandshakeHandler = (*CustomHandshakeHandler)(nil) +// ConnContext saves the connection attributes that are read by HandshakeHandler. +// These interfaces should not request for locks because HandshakeHandler already holds the lock. type ConnContext interface { ClientAddr() string ServerAddr() string @@ -36,6 +38,8 @@ type ConnContext interface { Value(key any) any } +// HandshakeHandler contains the hooks that are called during the connection lifecycle. +// All the interfaces should be called within a lock so that the interfaces of ConnContext are thread-safe. 
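The error.go hunk above maps context.Canceled to SrcProxyQuit: when Close or shutdown cancels the connection's context mid-handshake or mid-redirect, the resulting error is attributed to the proxy quitting rather than to a client or backend failure. A simplified standalone sketch of that classification — not the patch's full Error2Source, just the new branch:

// err_source_sketch.go — classify cancellation separately from other errors.
package main

import (
	"context"
	"errors"
	"fmt"
)

type errorSource string

const (
	srcProxyQuit errorSource = "proxy quit"
	srcProxyErr  errorSource = "proxy error"
)

func classify(err error) errorSource {
	// errors.Is unwraps wrapped errors, so cancellation surfaced through
	// handshake or redirect paths is still recognized.
	if errors.Is(err, context.Canceled) {
		return srcProxyQuit
	}
	return srcProxyErr
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	fmt.Println(classify(ctx.Err()))          // proxy quit
	fmt.Println(classify(errors.New("dial"))) // proxy error
}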
type HandshakeHandler interface { HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error HandleHandshakeErr(ctx ConnContext, err *mysql.MyError) bool // return true means retry connect diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 52a89680..49192058 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -4,6 +4,7 @@ package backend import ( + "context" "crypto/tls" "fmt" "testing" @@ -60,9 +61,10 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - if err := mp.authenticator.handshakeFirstTime(mp.logger, mp, clientIO, mp.handshakeHandler, func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { - return backendIO, nil - }, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { + if err := mp.authenticator.handshakeFirstTime(context.Background(), mp.logger, mp, clientIO, mp.handshakeHandler, + func(ctx context.Context, cctx ConnContext, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { + return backendIO, nil + }, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { return err } mp.cmdProcessor.capability = mp.authenticator.capability diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 42706768..9bbea062 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -152,6 +152,13 @@ func (s *SQLServer) Run(ctx context.Context, cfgch <-chan *config.Config) { func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { s.mu.Lock() + + if s.mu.status >= statusWaitShutdown { + s.mu.Unlock() + s.logger.Warn("server is shutting down while creating the connection", zap.String("client_addr", conn.RemoteAddr().Network()), zap.Error(conn.Close())) + return + } + conns := uint64(len(s.mu.clients)) maxConns := s.mu.maxConnections tcpKeepAlive := s.mu.tcpKeepAlive
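The proxy.go hunk closes the same window at the listener level: once the server has entered shutdown, onConn rejects the connection under s.mu before registering it, so no client can slip in between the shutdown check and the close. A standalone sketch with illustrative field names; the value of statusWaitShutdown is assumed:

// onconn_sketch.go — check the shutdown flag under the same mutex that
// registers clients, then close the rejected connection.
package main

import (
	"fmt"
	"net"
	"sync"
)

const statusWaitShutdown = 1

type server struct {
	mu struct {
		sync.Mutex
		status  int
		clients map[uint64]net.Conn
	}
}

func (s *server) onConn(conn net.Conn) {
	s.mu.Lock()
	if s.mu.status >= statusWaitShutdown {
		s.mu.Unlock()
		// Reject: the connection is closed before it is ever tracked.
		fmt.Println("shutting down, reject:", conn.Close())
		return
	}
	s.mu.clients[uint64(len(s.mu.clients))] = conn
	s.mu.Unlock()
}

func main() {
	c1, c2 := net.Pipe()
	defer c2.Close()
	s := &server{}
	s.mu.clients = map[uint64]net.Conn{}
	s.mu.status = statusWaitShutdown
	s.onConn(c1) // rejected and closed immediately
}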