Skip to content

Commit 31f752d

Browse files
committed
mongo: validate each shard on sharded cluster
1 parent 6b262de commit 31f752d

File tree

3 files changed

+133
-8
lines changed

3 files changed

+133
-8
lines changed

flow/connectors/mongo/validate.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@ import (
88
)
99

1010
func (c *MongoConnector) ValidateCheck(ctx context.Context) error {
11-
if err := shared_mongo.ValidateServerCompatibility(ctx, c.client); err != nil {
11+
if err := shared_mongo.ValidateServerCompatibility(ctx, c.client, shared_mongo.Credentials{
12+
Username: c.config.Username,
13+
Password: c.config.Password,
14+
DisableTls: c.config.DisableTls,
15+
RootCa: c.config.RootCa,
16+
TlsHost: c.config.TlsHost,
17+
}); err != nil {
1218
return err
1319
}
1420

@@ -24,7 +30,13 @@ func (c *MongoConnector) ValidateMirrorSource(ctx context.Context, cfg *protos.F
2430
return nil
2531
}
2632

27-
if err := shared_mongo.ValidateOplogRetention(ctx, c.client); err != nil {
33+
if err := shared_mongo.ValidateOplogRetention(ctx, c.client, shared_mongo.Credentials{
34+
Username: c.config.Username,
35+
Password: c.config.Password,
36+
DisableTls: c.config.DisableTls,
37+
RootCa: c.config.RootCa,
38+
TlsHost: c.config.TlsHost,
39+
}); err != nil {
2840
return err
2941
}
3042

flow/shared/mongo/commands.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,19 @@ func GetHelloResponse(ctx context.Context, client *mongo.Client) (HelloResponse,
6868
return runCommand[HelloResponse](ctx, client, "hello")
6969
}
7070

71+
type Shard struct {
72+
Host string `bson:"host"`
73+
}
74+
75+
type ListShards struct {
76+
Shards []Shard `bson:"shards"`
77+
Ok int `bson:"ok"`
78+
}
79+
80+
func GetListShards(ctx context.Context, client *mongo.Client) (ListShards, error) {
81+
return runCommand[ListShards](ctx, client, "listShards")
82+
}
83+
7184
func runCommand[T any](ctx context.Context, client *mongo.Client, command string) (T, error) {
7285
var result T
7386
singleResult := client.Database("admin").RunCommand(ctx, bson.D{

flow/shared/mongo/validation.go

Lines changed: 106 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@ package mongo
22

33
import (
44
"context"
5+
"crypto/tls"
56
"errors"
67
"fmt"
78
"slices"
9+
"strings"
10+
"sync"
811

912
"go.mongodb.org/mongo-driver/v2/mongo"
13+
"go.mongodb.org/mongo-driver/v2/mongo/options"
14+
15+
"github.com/PeerDB-io/peerdb/flow/shared"
1016
)
1117

1218
const (
@@ -19,7 +25,15 @@ const (
1925

2026
var RequiredRoles = [...]string{"readAnyDatabase", "clusterMonitor"}
2127

22-
func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) error {
28+
type Credentials struct {
29+
RootCa *string
30+
Username string
31+
Password string
32+
TlsHost string
33+
DisableTls bool
34+
}
35+
36+
func ValidateServerCompatibility(ctx context.Context, client *mongo.Client, credentials Credentials) error {
2337
buildInfo, err := GetBuildInfo(ctx, client)
2438
if err != nil {
2539
return err
@@ -51,8 +65,7 @@ func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) erro
5165
if topologyType == ReplicaSet {
5266
return validateStorageEngine(ctx, client)
5367
} else {
54-
// TODO: run validation on shard
55-
return nil
68+
return runOnShardsInParallel(ctx, client, credentials, validateStorageEngine)
5669
}
5770
}
5871

@@ -73,7 +86,7 @@ func ValidateUserRoles(ctx context.Context, client *mongo.Client) error {
7386
return nil
7487
}
7588

76-
func ValidateOplogRetention(ctx context.Context, client *mongo.Client) error {
89+
func ValidateOplogRetention(ctx context.Context, client *mongo.Client, credentials Credentials) error {
7790
validateOplogRetention := func(instanceCtx context.Context, instanceClient *mongo.Client) error {
7891
ss, err := GetServerStatus(instanceCtx, instanceClient)
7992
if err != nil {
@@ -94,8 +107,7 @@ func ValidateOplogRetention(ctx context.Context, client *mongo.Client) error {
94107
if topology == ReplicaSet {
95108
return validateOplogRetention(ctx, client)
96109
} else {
97-
// TODO: run validation on shard
98-
return nil
110+
return runOnShardsInParallel(ctx, client, credentials, validateOplogRetention)
99111
}
100112
}
101113

@@ -118,3 +130,91 @@ func GetTopologyType(ctx context.Context, client *mongo.Client) (string, error)
118130
}
119131
return "", errors.New("topology type must be ReplicaSet or ShardedCluster")
120132
}
133+
134+
func runOnShardsInParallel(
135+
ctx context.Context,
136+
client *mongo.Client,
137+
credentials Credentials,
138+
runCommand func(ctx context.Context, client *mongo.Client) error,
139+
) error {
140+
res, err := GetListShards(ctx, client)
141+
if err != nil {
142+
return err
143+
}
144+
145+
if res.Ok != 1 || len(res.Shards) == 0 {
146+
return errors.New("invalid shards")
147+
}
148+
149+
hosts := getUniqueClusterHosts(res.Shards)
150+
hostsErrors := make([]error, len(hosts))
151+
var wg sync.WaitGroup
152+
for idx, host := range hosts {
153+
wg.Add(1)
154+
go func(i int, h string) {
155+
defer wg.Done()
156+
157+
shardOpts := options.Client().
158+
ApplyURI("mongodb://" + h).
159+
SetDirect(true).
160+
SetAuth(options.Credential{
161+
Username: credentials.Username,
162+
Password: credentials.Password,
163+
})
164+
165+
if !credentials.DisableTls {
166+
tlsConfig, err := shared.CreateTlsConfig(tls.VersionTLS12, credentials.RootCa, "", credentials.TlsHost, false)
167+
if err != nil {
168+
hostsErrors[i] = fmt.Errorf("host %s TLS config error: %w", h, err)
169+
return
170+
}
171+
shardOpts.SetTLSConfig(tlsConfig)
172+
}
173+
174+
shardClient, err := mongo.Connect(shardOpts)
175+
if err != nil {
176+
hostsErrors[i] = fmt.Errorf("host %s connect error: %w", h, err)
177+
return
178+
}
179+
defer shardClient.Disconnect(ctx) //nolint:errcheck
180+
181+
if err := runCommand(ctx, shardClient); err != nil {
182+
hostsErrors[i] = fmt.Errorf("host %s command error: %w", h, err)
183+
return
184+
}
185+
}(idx, host)
186+
}
187+
wg.Wait()
188+
189+
for _, err = range hostsErrors {
190+
if err != nil {
191+
return err
192+
}
193+
}
194+
195+
return nil
196+
}
197+
198+
func getUniqueClusterHosts(shards []Shard) []string {
199+
hostSet := make(map[string]bool)
200+
for _, shard := range shards {
201+
hosts := shard.Host
202+
if slashIdx := strings.Index(hosts, "/"); slashIdx != -1 {
203+
hosts = hosts[slashIdx+1:]
204+
}
205+
206+
for _, host := range strings.Split(hosts, ",") {
207+
host = strings.TrimSpace(host)
208+
if host != "" {
209+
hostSet[host] = true
210+
}
211+
}
212+
}
213+
214+
hosts := make([]string, 0, len(hostSet))
215+
for host := range hostSet {
216+
hosts = append(hosts, host)
217+
}
218+
219+
return hosts
220+
}

0 commit comments

Comments
 (0)