@@ -41,7 +41,9 @@ type Balancer struct {
4141 localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
4242
4343 mu xsync.RWMutex
44- connectionsState * connectionsState
44+ connectionsState * connectionsState [conn.Conn ]
45+
46+ closed chan struct {}
4547
4648 onApplyDiscoveredEndpoints []func (ctx context.Context , endpoints []endpoint.Info )
4749}
@@ -133,7 +135,7 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
133135 return nil
134136}
135137
136- func endpointsDiff (newestEndpoints []endpoint.Endpoint , previousConns []conn.Conn ) (
138+ func endpointsDiff (newestEndpoints []endpoint.Endpoint , previousConns []conn.Info ) (
137139 nodes []trace.EndpointInfo ,
138140 added []trace.EndpointInfo ,
139141 dropped []trace.EndpointInfo ,
@@ -178,7 +180,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
178180 "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
179181 b .config .DetectLocalDC ,
180182 )
181- previousConns []conn.Conn
183+ previousConns []conn.Info
182184 )
183185 defer func () {
184186 nodes , added , dropped := endpointsDiff (endpoints , previousConns )
@@ -187,7 +189,9 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
187189
188190 connections := endpointsToConnections (b .pool , endpoints )
189191 for _ , c := range connections {
190- b .pool .Allow (ctx , c )
192+ if c .State () == conn .Banned {
193+ b .pool .Unban (ctx , c )
194+ }
191195 c .Endpoint ().Touch ()
192196 }
193197
@@ -201,7 +205,10 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
201205
202206 b .mu .WithLock (func () {
203207 if b .connectionsState != nil {
204- previousConns = b .connectionsState .all
208+ previousConns = make ([]conn.Info , len (b .connectionsState .all ))
209+ for i := range b .connectionsState .all {
210+ previousConns [i ] = b .connectionsState .all [i ]
211+ }
205212 }
206213 b .connectionsState = state
207214 for _ , onApplyDiscoveredEndpoints := range b .onApplyDiscoveredEndpoints {
@@ -211,6 +218,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
211218}
212219
213220func (b * Balancer ) Close (ctx context.Context ) (err error ) {
221+ close (b .closed )
222+
214223 onDone := trace .DriverOnBalancerClose (
215224 b .driverConfig .Trace (), & ctx ,
216225 stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).Close" ),
@@ -223,6 +232,8 @@ func (b *Balancer) Close(ctx context.Context) (err error) {
223232 b .discoveryRepeater .Stop ()
224233 }
225234
235+ b .applyDiscoveredEndpoints (ctx , nil , "" )
236+
226237 if err = b .discoveryClient .Close (ctx ); err != nil {
227238 return xerrors .WithStackTrace (err )
228239 }
@@ -258,6 +269,7 @@ func New(
258269 driverConfig : driverConfig ,
259270 pool : pool ,
260271 localDCDetector : detectLocalDC ,
272+ closed : make (chan struct {}),
261273 }
262274 d := internalDiscovery .New (ctx , pool .Get (
263275 endpoint .New (driverConfig .Endpoint ()),
@@ -300,9 +312,14 @@ func (b *Balancer) Invoke(
300312 reply interface {},
301313 opts ... grpc.CallOption ,
302314) error {
303- return b .wrapCall (ctx , func (ctx context.Context , cc conn.Conn ) error {
304- return cc .Invoke (ctx , method , args , reply , opts ... )
305- })
315+ select {
316+ case <- b .closed :
317+ return xerrors .WithStackTrace (errBalancerClosed )
318+ default :
319+ return b .wrapCall (ctx , func (ctx context.Context , cc conn.Conn ) error {
320+ return cc .Invoke (ctx , method , args , reply , opts ... )
321+ })
322+ }
306323}
307324
308325func (b * Balancer ) NewStream (
@@ -311,17 +328,22 @@ func (b *Balancer) NewStream(
311328 method string ,
312329 opts ... grpc.CallOption ,
313330) (_ grpc.ClientStream , err error ) {
314- var client grpc.ClientStream
315- err = b .wrapCall (ctx , func (ctx context.Context , cc conn.Conn ) error {
316- client , err = cc .NewStream (ctx , desc , method , opts ... )
331+ select {
332+ case <- b .closed :
333+ return nil , xerrors .WithStackTrace (errBalancerClosed )
334+ default :
335+ var client grpc.ClientStream
336+ err = b .wrapCall (ctx , func (ctx context.Context , cc conn.Conn ) error {
337+ client , err = cc .NewStream (ctx , desc , method , opts ... )
338+
339+ return err
340+ })
341+ if err == nil {
342+ return client , nil
343+ }
317344
318- return err
319- })
320- if err == nil {
321- return client , nil
345+ return nil , err
322346 }
323-
324- return nil , err
325347}
326348
327349func (b * Balancer ) wrapCall (ctx context.Context , f func (ctx context.Context , cc conn.Conn ) error ) (err error ) {
@@ -332,10 +354,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
332354
333355 defer func () {
334356 if err == nil {
335- if cc .GetState () == conn .Banned {
336- b .pool .Allow (ctx , cc )
337- }
338- } else if xerrors .MustPessimizeEndpoint (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
357+ b .pool .Unban (ctx , cc )
358+ } else if xerrors .MustBanConn (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
339359 b .pool .Ban (ctx , cc , err )
340360 }
341361 }()
@@ -363,7 +383,7 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
363383 return nil
364384}
365385
366- func (b * Balancer ) connections () * connectionsState {
386+ func (b * Balancer ) connections () * connectionsState [conn. Conn ] {
367387 b .mu .RLock ()
368388 defer b .mu .RUnlock ()
369389
@@ -401,7 +421,7 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
401421 c , failedCount = state .GetConnection (ctx )
402422 if c == nil {
403423 return nil , xerrors .WithStackTrace (
404- fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , ErrNoEndpoints , failedCount ),
424+ fmt .Errorf ("cannot get connection from Balancer after %d attempts: %w " , failedCount , ErrNoEndpoints ),
405425 )
406426 }
407427
0 commit comments