9
9
"google.golang.org/grpc"
10
10
11
11
"github.com/ydb-platform/ydb-go-sdk/v3/config"
12
+ "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/cluster"
12
13
balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
13
14
"github.com/ydb-platform/ydb-go-sdk/v3/internal/closer"
14
15
"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
@@ -26,8 +27,6 @@ import (
26
27
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
27
28
)
28
29
29
- var ErrNoEndpoints = xerrors .Wrap (fmt .Errorf ("no endpoints" ))
30
-
31
30
type discoveryClient interface {
32
31
closer.Closer
33
32
@@ -40,9 +39,12 @@ type Balancer struct {
40
39
pool * conn.Pool
41
40
discoveryClient discoveryClient
42
41
discoveryRepeater repeater.Repeater
43
- localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
44
42
45
- connectionsState atomic.Pointer [connectionsState ]
43
+ cluster atomic.Pointer [cluster.Cluster ]
44
+ conns xsync.Map [endpoint.Endpoint , conn.Conn ]
45
+ banned xsync.Set [endpoint.Endpoint ]
46
+
47
+ localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
46
48
47
49
mu xsync.RWMutex
48
50
onApplyDiscoveredEndpoints []func (ctx context.Context , endpoints []endpoint.Info )
@@ -124,19 +126,48 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
124
126
}
125
127
126
128
func (b * Balancer ) applyDiscoveredEndpoints (ctx context.Context , newest []endpoint.Endpoint , localDC string ) {
127
- var (
128
- onDone = trace .DriverOnBalancerUpdate (
129
- b .driverConfig .Trace (), & ctx ,
130
- stack .FunctionID (
131
- "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
132
- b .config .DetectLocalDC ,
133
- )
134
- previous = b .connections ().All ()
129
+ onDone := trace .DriverOnBalancerUpdate (
130
+ b .driverConfig .Trace (), & ctx ,
131
+ stack .FunctionID (
132
+ "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
133
+ b .config .DetectLocalDC ,
135
134
)
135
+
136
+ state := cluster .New (newest ,
137
+ cluster .WithFilter (func (e endpoint.Info ) bool {
138
+ if b .config .Filter == nil {
139
+ return true
140
+ }
141
+
142
+ return b .config .Filter .Allow (balancerConfig.Info {SelfLocation : localDC }, e )
143
+ }),
144
+ cluster .WithFallback (b .config .AllowFallback ),
145
+ )
146
+
147
+ previous := b .cluster .Swap (state )
148
+
149
+ _ , added , dropped := xslices .Diff (previous .All (), newest , func (lhs , rhs endpoint.Endpoint ) int {
150
+ return strings .Compare (lhs .Address (), rhs .Address ())
151
+ })
152
+
153
+ for _ , e := range dropped {
154
+ c , ok := b .conns .Extract (e )
155
+ if ! ok {
156
+ panic ("wrong balancer state" )
157
+ }
158
+ b .pool .Put (ctx , c )
159
+ }
160
+
161
+ for _ , e := range added {
162
+ cc , err := b .pool .Get (ctx , e )
163
+ if err != nil {
164
+ b .banned .Add (e )
165
+ } else {
166
+ b .conns .Set (e , cc )
167
+ }
168
+ }
169
+
136
170
defer func () {
137
- _ , added , dropped := xslices .Diff (previous , newest , func (lhs , rhs endpoint.Endpoint ) int {
138
- return strings .Compare (lhs .Address (), rhs .Address ())
139
- })
140
171
onDone (
141
172
xslices .Transform (newest , func (t endpoint.Endpoint ) trace.EndpointInfo { return t }),
142
173
xslices .Transform (added , func (t endpoint.Endpoint ) trace.EndpointInfo { return t }),
@@ -145,25 +176,13 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi
145
176
)
146
177
}()
147
178
148
- connections := endpointsToConnections (b .pool , newest )
149
- for _ , c := range connections {
150
- b .pool .Allow (ctx , c )
151
- c .Endpoint ().Touch ()
152
- }
153
-
154
- info := balancerConfig.Info {SelfLocation : localDC }
155
- state := newConnectionsState (connections , b .config .Filter , info , b .config .AllowFallback )
156
-
157
- endpointsInfo := make ([]endpoint.Info , len (newest ))
158
- for i , e := range newest {
159
- endpointsInfo [i ] = e
160
- }
161
-
162
- b .connectionsState .Store (state )
179
+ endpoints := xslices .Transform (newest , func (e endpoint.Endpoint ) endpoint.Info {
180
+ return e
181
+ })
163
182
164
183
b .mu .WithLock (func () {
165
184
for _ , onApplyDiscoveredEndpoints := range b .onApplyDiscoveredEndpoints {
166
- onApplyDiscoveredEndpoints (ctx , endpointsInfo )
185
+ onApplyDiscoveredEndpoints (ctx , endpoints )
167
186
}
168
187
})
169
188
}
@@ -212,18 +231,20 @@ func New(
212
231
onDone (finalErr )
213
232
}()
214
233
234
+ cc , err := pool .Get (ctx , endpoint .New (driverConfig .Endpoint ()))
235
+ if err != nil {
236
+ return nil , xerrors .WithStackTrace (err )
237
+ }
238
+
215
239
b = & Balancer {
216
- driverConfig : driverConfig ,
217
- pool : pool ,
218
- discoveryClient : internalDiscovery .New (ctx , pool .Get (
219
- endpoint .New (driverConfig .Endpoint ()),
220
- ), discoveryConfig ),
240
+ config : balancerConfig.Config {},
241
+ driverConfig : driverConfig ,
242
+ pool : pool ,
243
+ discoveryClient : internalDiscovery .New (ctx , cc , discoveryConfig ),
221
244
localDCDetector : detectLocalDC ,
222
245
}
223
246
224
- if config := driverConfig .Balancer (); config == nil {
225
- b .config = balancerConfig.Config {}
226
- } else {
247
+ if config := driverConfig .Balancer (); config != nil {
227
248
b .config = * config
228
249
}
229
250
@@ -289,10 +310,10 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
289
310
defer func () {
290
311
if err == nil {
291
312
if cc .GetState () == conn .Banned {
292
- b .pool . Allow ( ctx , cc )
313
+ b .banned . Remove ( cc . Endpoint () )
293
314
}
294
- } else if xerrors . MustPessimizeEndpoint (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
295
- b .pool . Ban ( ctx , cc , err )
315
+ } else if conn . IsBadConn (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
316
+ b .banned . Add ( cc . Endpoint () )
296
317
}
297
318
}()
298
319
@@ -319,53 +340,46 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
319
340
return nil
320
341
}
321
342
322
- func (b * Balancer ) connections () * connectionsState {
323
- return b .connectionsState .Load ()
324
- }
325
-
326
343
func (b * Balancer ) getConn (ctx context.Context ) (c conn.Conn , err error ) {
327
- onDone := trace .DriverOnBalancerChooseEndpoint (
328
- b .driverConfig .Trace (), & ctx ,
329
- stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn" ),
344
+ var (
345
+ onDone = trace .DriverOnBalancerChooseEndpoint (
346
+ b .driverConfig .Trace (), & ctx ,
347
+ stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn" ),
348
+ )
349
+ state = b .cluster .Load ()
330
350
)
351
+
331
352
defer func () {
353
+ b .cluster .Store (state )
354
+
355
+ if b .discoveryRepeater != nil {
356
+ b .discoveryRepeater .Force ()
357
+ }
358
+
332
359
if err == nil {
333
360
onDone (c .Endpoint (), nil )
334
361
} else {
335
362
onDone (nil , err )
336
363
}
337
364
}()
338
365
339
- if err = ctx .Err (); err != nil {
340
- return nil , xerrors .WithStackTrace (err )
341
- }
342
-
343
- var (
344
- state = b .connections ()
345
- failedCount int
346
- )
347
-
348
- defer func () {
349
- if failedCount * 2 > state .PreferredCount () && b .discoveryRepeater != nil {
350
- b .discoveryRepeater .Force ()
366
+ for attempts := 1 ; ; attempts ++ {
367
+ if err = ctx .Err (); err != nil {
368
+ return nil , xerrors .WithStackTrace (err )
351
369
}
352
- }()
353
370
354
- c , failedCount = state .GetConnection (ctx )
355
- if c = = nil {
356
- return nil , xerrors .WithStackTrace (
357
- fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , ErrNoEndpoints , failedCount ),
358
- )
359
- }
371
+ e , err : = state .Next (ctx )
372
+ if err ! = nil {
373
+ return nil , xerrors .WithStackTrace (
374
+ fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , cluster . ErrNoEndpoints , attempts ),
375
+ )
376
+ }
360
377
361
- return c , nil
362
- }
378
+ cc , err := b .pool .Get (ctx , e )
379
+ if err == nil {
380
+ return cc , nil
381
+ }
363
382
364
- func endpointsToConnections (p * conn.Pool , endpoints []endpoint.Endpoint ) []conn.Conn {
365
- conns := make ([]conn.Conn , 0 , len (endpoints ))
366
- for _ , e := range endpoints {
367
- conns = append (conns , p .Get (e ))
383
+ b .banned .Add (e )
368
384
}
369
-
370
- return conns
371
385
}
0 commit comments