Skip to content

Commit 0628d5d

Browse files
committed
WIP
1 parent 1a85c10 commit 0628d5d

17 files changed

+316
-1020
lines changed

driver.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
internalDiscovery "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery"
2121
discoveryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config"
2222
"github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn"
23-
"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
2423
internalQuery "github.com/ydb-platform/ydb-go-sdk/v3/internal/query"
2524
queryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/config"
2625
internalRatelimiter "github.com/ydb-platform/ydb-go-sdk/v3/internal/ratelimiter"
@@ -488,7 +487,7 @@ func (d *Driver) connect(ctx context.Context) (err error) {
488487

489488
d.discovery = xsync.OnceValue(func() (*internalDiscovery.Client, error) {
490489
return internalDiscovery.New(xcontext.ValueOnly(ctx),
491-
d.pool.Get(endpoint.New(d.config.Endpoint())),
490+
d.balancer,
492491
discoveryConfig.New(
493492
append(
494493
// prepend common params from root config

internal/balancer/balancer.go

+90-76
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"google.golang.org/grpc"
1010

1111
"github.com/ydb-platform/ydb-go-sdk/v3/config"
12+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/cluster"
1213
balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
1314
"github.com/ydb-platform/ydb-go-sdk/v3/internal/closer"
1415
"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
@@ -26,8 +27,6 @@ import (
2627
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
2728
)
2829

29-
var ErrNoEndpoints = xerrors.Wrap(fmt.Errorf("no endpoints"))
30-
3130
type discoveryClient interface {
3231
closer.Closer
3332

@@ -40,9 +39,12 @@ type Balancer struct {
4039
pool *conn.Pool
4140
discoveryClient discoveryClient
4241
discoveryRepeater repeater.Repeater
43-
localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)
4442

45-
connectionsState atomic.Pointer[connectionsState]
43+
cluster atomic.Pointer[cluster.Cluster]
44+
conns xsync.Map[endpoint.Endpoint, conn.Conn]
45+
banned xsync.Set[endpoint.Endpoint]
46+
47+
localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)
4648

4749
mu xsync.RWMutex
4850
onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info)
@@ -124,19 +126,48 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
124126
}
125127

126128
func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoint.Endpoint, localDC string) {
127-
var (
128-
onDone = trace.DriverOnBalancerUpdate(
129-
b.driverConfig.Trace(), &ctx,
130-
stack.FunctionID(
131-
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
132-
b.config.DetectLocalDC,
133-
)
134-
previous = b.connections().All()
129+
onDone := trace.DriverOnBalancerUpdate(
130+
b.driverConfig.Trace(), &ctx,
131+
stack.FunctionID(
132+
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
133+
b.config.DetectLocalDC,
135134
)
135+
136+
state := cluster.New(newest,
137+
cluster.WithFilter(func(e endpoint.Info) bool {
138+
if b.config.Filter == nil {
139+
return true
140+
}
141+
142+
return b.config.Filter.Allow(balancerConfig.Info{SelfLocation: localDC}, e)
143+
}),
144+
cluster.WithFallback(b.config.AllowFallback),
145+
)
146+
147+
previous := b.cluster.Swap(state)
148+
149+
_, added, dropped := xslices.Diff(previous.All(), newest, func(lhs, rhs endpoint.Endpoint) int {
150+
return strings.Compare(lhs.Address(), rhs.Address())
151+
})
152+
153+
for _, e := range dropped {
154+
c, ok := b.conns.Extract(e)
155+
if !ok {
156+
panic("wrong balancer state")
157+
}
158+
b.pool.Put(ctx, c)
159+
}
160+
161+
for _, e := range added {
162+
cc, err := b.pool.Get(ctx, e)
163+
if err != nil {
164+
b.banned.Add(e)
165+
} else {
166+
b.conns.Set(e, cc)
167+
}
168+
}
169+
136170
defer func() {
137-
_, added, dropped := xslices.Diff(previous, newest, func(lhs, rhs endpoint.Endpoint) int {
138-
return strings.Compare(lhs.Address(), rhs.Address())
139-
})
140171
onDone(
141172
xslices.Transform(newest, func(t endpoint.Endpoint) trace.EndpointInfo { return t }),
142173
xslices.Transform(added, func(t endpoint.Endpoint) trace.EndpointInfo { return t }),
@@ -145,25 +176,13 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi
145176
)
146177
}()
147178

148-
connections := endpointsToConnections(b.pool, newest)
149-
for _, c := range connections {
150-
b.pool.Allow(ctx, c)
151-
c.Endpoint().Touch()
152-
}
153-
154-
info := balancerConfig.Info{SelfLocation: localDC}
155-
state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback)
156-
157-
endpointsInfo := make([]endpoint.Info, len(newest))
158-
for i, e := range newest {
159-
endpointsInfo[i] = e
160-
}
161-
162-
b.connectionsState.Store(state)
179+
endpoints := xslices.Transform(newest, func(e endpoint.Endpoint) endpoint.Info {
180+
return e
181+
})
163182

164183
b.mu.WithLock(func() {
165184
for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints {
166-
onApplyDiscoveredEndpoints(ctx, endpointsInfo)
185+
onApplyDiscoveredEndpoints(ctx, endpoints)
167186
}
168187
})
169188
}
@@ -212,18 +231,20 @@ func New(
212231
onDone(finalErr)
213232
}()
214233

234+
cc, err := pool.Get(ctx, endpoint.New(driverConfig.Endpoint()))
235+
if err != nil {
236+
return nil, xerrors.WithStackTrace(err)
237+
}
238+
215239
b = &Balancer{
216-
driverConfig: driverConfig,
217-
pool: pool,
218-
discoveryClient: internalDiscovery.New(ctx, pool.Get(
219-
endpoint.New(driverConfig.Endpoint()),
220-
), discoveryConfig),
240+
config: balancerConfig.Config{},
241+
driverConfig: driverConfig,
242+
pool: pool,
243+
discoveryClient: internalDiscovery.New(ctx, cc, discoveryConfig),
221244
localDCDetector: detectLocalDC,
222245
}
223246

224-
if config := driverConfig.Balancer(); config == nil {
225-
b.config = balancerConfig.Config{}
226-
} else {
247+
if config := driverConfig.Balancer(); config != nil {
227248
b.config = *config
228249
}
229250

@@ -289,10 +310,10 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
289310
defer func() {
290311
if err == nil {
291312
if cc.GetState() == conn.Banned {
292-
b.pool.Allow(ctx, cc)
313+
b.banned.Remove(cc.Endpoint())
293314
}
294-
} else if xerrors.MustPessimizeEndpoint(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
295-
b.pool.Ban(ctx, cc, err)
315+
} else if conn.IsBadConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
316+
b.banned.Add(cc.Endpoint())
296317
}
297318
}()
298319

@@ -319,53 +340,46 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
319340
return nil
320341
}
321342

322-
func (b *Balancer) connections() *connectionsState {
323-
return b.connectionsState.Load()
324-
}
325-
326343
func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
327-
onDone := trace.DriverOnBalancerChooseEndpoint(
328-
b.driverConfig.Trace(), &ctx,
329-
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn"),
344+
var (
345+
onDone = trace.DriverOnBalancerChooseEndpoint(
346+
b.driverConfig.Trace(), &ctx,
347+
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn"),
348+
)
349+
state = b.cluster.Load()
330350
)
351+
331352
defer func() {
353+
b.cluster.Store(state)
354+
355+
if b.discoveryRepeater != nil {
356+
b.discoveryRepeater.Force()
357+
}
358+
332359
if err == nil {
333360
onDone(c.Endpoint(), nil)
334361
} else {
335362
onDone(nil, err)
336363
}
337364
}()
338365

339-
if err = ctx.Err(); err != nil {
340-
return nil, xerrors.WithStackTrace(err)
341-
}
342-
343-
var (
344-
state = b.connections()
345-
failedCount int
346-
)
347-
348-
defer func() {
349-
if failedCount*2 > state.PreferredCount() && b.discoveryRepeater != nil {
350-
b.discoveryRepeater.Force()
366+
for attempts := 1; ; attempts++ {
367+
if err = ctx.Err(); err != nil {
368+
return nil, xerrors.WithStackTrace(err)
351369
}
352-
}()
353370

354-
c, failedCount = state.GetConnection(ctx)
355-
if c == nil {
356-
return nil, xerrors.WithStackTrace(
357-
fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount),
358-
)
359-
}
371+
e, err := state.Next(ctx)
372+
if err != nil {
373+
return nil, xerrors.WithStackTrace(
374+
fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", cluster.ErrNoEndpoints, attempts),
375+
)
376+
}
360377

361-
return c, nil
362-
}
378+
cc, err := b.pool.Get(ctx, e)
379+
if err == nil {
380+
return cc, nil
381+
}
363382

364-
func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn {
365-
conns := make([]conn.Conn, 0, len(endpoints))
366-
for _, e := range endpoints {
367-
conns = append(conns, p.Get(e))
383+
b.banned.Add(e)
368384
}
369-
370-
return conns
371385
}

0 commit comments

Comments
 (0)