Skip to content

Commit 9ba694f

Browse files
committed
mongo: validate each shard on sharded cluster
1 parent 2dd8aec commit 9ba694f

File tree

3 files changed

+132
-8
lines changed

3 files changed

+132
-8
lines changed

flow/connectors/mongo/validate.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@ import (
88
)
99

1010
func (c *MongoConnector) ValidateCheck(ctx context.Context) error {
11-
if err := shared_mongo.ValidateServerCompatibility(ctx, c.client); err != nil {
11+
if err := shared_mongo.ValidateServerCompatibility(ctx, c.client, shared_mongo.Credentials{
12+
Username: c.config.Username,
13+
Password: c.config.Password,
14+
DisableTls: c.config.DisableTls,
15+
RootCa: c.config.RootCa,
16+
TlsHost: c.config.TlsHost,
17+
}); err != nil {
1218
return err
1319
}
1420

@@ -24,7 +30,13 @@ func (c *MongoConnector) ValidateMirrorSource(ctx context.Context, cfg *protos.F
2430
return nil
2531
}
2632

27-
if err := shared_mongo.ValidateOplogRetention(ctx, c.client); err != nil {
33+
if err := shared_mongo.ValidateOplogRetention(ctx, c.client, shared_mongo.Credentials{
34+
Username: c.config.Username,
35+
Password: c.config.Password,
36+
DisableTls: c.config.DisableTls,
37+
RootCa: c.config.RootCa,
38+
TlsHost: c.config.TlsHost,
39+
}); err != nil {
2840
return err
2941
}
3042

flow/shared/mongo/commands.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,19 @@ func GetHelloResponse(ctx context.Context, client *mongo.Client) (*HelloResponse
6868
return runCommand[HelloResponse](ctx, client, "hello")
6969
}
7070

71+
type Shard struct {
72+
Host string `bson:"host"`
73+
}
74+
75+
type ListShards struct {
76+
Shards []Shard `bson:"shards"`
77+
Ok int `bson:"ok"`
78+
}
79+
80+
func GetListShards(ctx context.Context, client *mongo.Client) (*ListShards, error) {
81+
return runCommand[ListShards](ctx, client, "listShards")
82+
}
83+
7184
func runCommand[T any](ctx context.Context, client *mongo.Client, command string) (*T, error) {
7285
singleResult := client.Database("admin").RunCommand(ctx, bson.D{
7386
bson.E{Key: command, Value: 1},

flow/shared/mongo/validation.go

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@ package mongo
22

33
import (
44
"context"
5+
"crypto/tls"
56
"errors"
67
"fmt"
8+
"strings"
9+
"sync"
710

11+
"github.com/PeerDB-io/peerdb/flow/shared"
812
"go.mongodb.org/mongo-driver/v2/mongo"
13+
"go.mongodb.org/mongo-driver/v2/mongo/options"
914
)
1015

1116
const (
@@ -16,7 +21,15 @@ const (
1621
ShardedCluster = "ShardedCluster"
1722
)
1823

19-
func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) error {
24+
type Credentials struct {
25+
RootCa *string
26+
Username string
27+
Password string
28+
TlsHost string
29+
DisableTls bool
30+
}
31+
32+
func ValidateServerCompatibility(ctx context.Context, client *mongo.Client, credentials Credentials) error {
2033
buildInfo, err := GetBuildInfo(ctx, client)
2134
if err != nil {
2235
return err
@@ -48,8 +61,7 @@ func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) erro
4861
if topologyType == ReplicaSet {
4962
return validateStorageEngine(ctx, client)
5063
} else {
51-
// TODO: run validation on shard
52-
return nil
64+
return runOnShardsInParallel(ctx, client, credentials, validateStorageEngine)
5365
}
5466
}
5567

@@ -79,7 +91,7 @@ func ValidateUserRoles(ctx context.Context, client *mongo.Client) error {
7991
return nil
8092
}
8193

82-
func ValidateOplogRetention(ctx context.Context, client *mongo.Client) error {
94+
func ValidateOplogRetention(ctx context.Context, client *mongo.Client, credentials Credentials) error {
8395
validateOplogRetention := func(instanceCtx context.Context, instanceClient *mongo.Client) error {
8496
ss, err := GetServerStatus(instanceCtx, instanceClient)
8597
if err != nil {
@@ -100,8 +112,7 @@ func ValidateOplogRetention(ctx context.Context, client *mongo.Client) error {
100112
if topology == ReplicaSet {
101113
return validateOplogRetention(ctx, client)
102114
} else {
103-
// TODO: run validation on shard
104-
return nil
115+
return runOnShardsInParallel(ctx, client, credentials, validateOplogRetention)
105116
}
106117
}
107118

@@ -124,3 +135,91 @@ func GetTopologyType(ctx context.Context, client *mongo.Client) (string, error)
124135
}
125136
return "", errors.New("topology type must be ReplicaSet or ShardedCluster")
126137
}
138+
139+
func runOnShardsInParallel(
140+
ctx context.Context,
141+
client *mongo.Client,
142+
credentials Credentials,
143+
runCommand func(ctx context.Context, client *mongo.Client) error,
144+
) error {
145+
res, err := GetListShards(ctx, client)
146+
if err != nil {
147+
return err
148+
}
149+
150+
if res.Ok != 1 || len(res.Shards) == 0 {
151+
return errors.New("invalid shards")
152+
}
153+
154+
hosts := getUniqueClusterHosts(res.Shards)
155+
hostsErrors := make([]error, len(hosts))
156+
var wg sync.WaitGroup
157+
for idx, host := range hosts {
158+
wg.Add(1)
159+
go func(i int, h string) {
160+
defer wg.Done()
161+
162+
shardOpts := options.Client().
163+
ApplyURI("mongodb://" + h).
164+
SetDirect(true).
165+
SetAuth(options.Credential{
166+
Username: credentials.Username,
167+
Password: credentials.Password,
168+
})
169+
170+
if !credentials.DisableTls {
171+
tlsConfig, err := shared.CreateTlsConfig(tls.VersionTLS12, credentials.RootCa, "", credentials.TlsHost, false)
172+
if err != nil {
173+
hostsErrors[i] = fmt.Errorf("host %s TLS config error: %w", h, err)
174+
return
175+
}
176+
shardOpts.SetTLSConfig(tlsConfig)
177+
}
178+
179+
shardClient, err := mongo.Connect(shardOpts)
180+
if err != nil {
181+
hostsErrors[i] = fmt.Errorf("host %s connect error: %w", h, err)
182+
return
183+
}
184+
defer shardClient.Disconnect(ctx) //nolint:errcheck
185+
186+
if err := runCommand(ctx, shardClient); err != nil {
187+
hostsErrors[i] = fmt.Errorf("host %s command error: %w", h, err)
188+
return
189+
}
190+
}(idx, host)
191+
}
192+
wg.Wait()
193+
194+
for _, err = range hostsErrors {
195+
if err != nil {
196+
return err
197+
}
198+
}
199+
200+
return nil
201+
}
202+
203+
func getUniqueClusterHosts(shards []Shard) []string {
204+
hostSet := make(map[string]bool)
205+
for _, shard := range shards {
206+
hosts := shard.Host
207+
if slashIdx := strings.Index(hosts, "/"); slashIdx != -1 {
208+
hosts = hosts[slashIdx+1:]
209+
}
210+
211+
for _, host := range strings.Split(hosts, ",") {
212+
host = strings.TrimSpace(host)
213+
if host != "" {
214+
hostSet[host] = true
215+
}
216+
}
217+
}
218+
219+
hosts := make([]string, 0, len(hostSet))
220+
for host := range hostSet {
221+
hosts = append(hosts, host)
222+
}
223+
224+
return hosts
225+
}

0 commit comments

Comments
 (0)