@@ -2,11 +2,17 @@ package mongo
22
33import (
44 "context"
5+ "crypto/tls"
56 "errors"
67 "fmt"
78 "slices"
9+ "strings"
10+ "sync"
811
912 "go.mongodb.org/mongo-driver/v2/mongo"
13+ "go.mongodb.org/mongo-driver/v2/mongo/options"
14+
15+ "github.com/PeerDB-io/peerdb/flow/shared"
1016)
1117
1218const (
@@ -19,7 +25,15 @@ const (
1925
2026var RequiredRoles = [... ]string {"readAnyDatabase" , "clusterMonitor" }
2127
22- func ValidateServerCompatibility (ctx context.Context , client * mongo.Client ) error {
28+ type Credentials struct {
29+ RootCa * string
30+ Username string
31+ Password string
32+ TlsHost string
33+ DisableTls bool
34+ }
35+
36+ func ValidateServerCompatibility (ctx context.Context , client * mongo.Client , credentials Credentials ) error {
2337 buildInfo , err := GetBuildInfo (ctx , client )
2438 if err != nil {
2539 return err
@@ -51,8 +65,7 @@ func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) erro
5165 if topologyType == ReplicaSet {
5266 return validateStorageEngine (ctx , client )
5367 } else {
54- // TODO: run validation on shard
55- return nil
68+ return runOnShardsInParallel (ctx , client , credentials , validateStorageEngine )
5669 }
5770}
5871
@@ -73,7 +86,7 @@ func ValidateUserRoles(ctx context.Context, client *mongo.Client) error {
7386 return nil
7487}
7588
76- func ValidateOplogRetention (ctx context.Context , client * mongo.Client ) error {
89+ func ValidateOplogRetention (ctx context.Context , client * mongo.Client , credentials Credentials ) error {
7790 validateOplogRetention := func (instanceCtx context.Context , instanceClient * mongo.Client ) error {
7891 ss , err := GetServerStatus (instanceCtx , instanceClient )
7992 if err != nil {
@@ -94,8 +107,7 @@ func ValidateOplogRetention(ctx context.Context, client *mongo.Client) error {
94107 if topology == ReplicaSet {
95108 return validateOplogRetention (ctx , client )
96109 } else {
97- // TODO: run validation on shard
98- return nil
110+ return runOnShardsInParallel (ctx , client , credentials , validateOplogRetention )
99111 }
100112}
101113
@@ -118,3 +130,91 @@ func GetTopologyType(ctx context.Context, client *mongo.Client) (string, error)
118130 }
119131 return "" , errors .New ("topology type must be ReplicaSet or ShardedCluster" )
120132}
133+
134+ func runOnShardsInParallel (
135+ ctx context.Context ,
136+ client * mongo.Client ,
137+ credentials Credentials ,
138+ runCommand func (ctx context.Context , client * mongo.Client ) error ,
139+ ) error {
140+ res , err := GetListShards (ctx , client )
141+ if err != nil {
142+ return err
143+ }
144+
145+ if res .Ok != 1 || len (res .Shards ) == 0 {
146+ return errors .New ("invalid shards" )
147+ }
148+
149+ hosts := getUniqueClusterHosts (res .Shards )
150+ hostsErrors := make ([]error , len (hosts ))
151+ var wg sync.WaitGroup
152+ for idx , host := range hosts {
153+ wg .Add (1 )
154+ go func (i int , h string ) {
155+ defer wg .Done ()
156+
157+ shardOpts := options .Client ().
158+ ApplyURI ("mongodb://" + h ).
159+ SetDirect (true ).
160+ SetAuth (options.Credential {
161+ Username : credentials .Username ,
162+ Password : credentials .Password ,
163+ })
164+
165+ if ! credentials .DisableTls {
166+ tlsConfig , err := shared .CreateTlsConfig (tls .VersionTLS12 , credentials .RootCa , "" , credentials .TlsHost , false )
167+ if err != nil {
168+ hostsErrors [i ] = fmt .Errorf ("host %s TLS config error: %w" , h , err )
169+ return
170+ }
171+ shardOpts .SetTLSConfig (tlsConfig )
172+ }
173+
174+ shardClient , err := mongo .Connect (shardOpts )
175+ if err != nil {
176+ hostsErrors [i ] = fmt .Errorf ("host %s connect error: %w" , h , err )
177+ return
178+ }
179+ defer shardClient .Disconnect (ctx ) //nolint:errcheck
180+
181+ if err := runCommand (ctx , shardClient ); err != nil {
182+ hostsErrors [i ] = fmt .Errorf ("host %s command error: %w" , h , err )
183+ return
184+ }
185+ }(idx , host )
186+ }
187+ wg .Wait ()
188+
189+ for _ , err = range hostsErrors {
190+ if err != nil {
191+ return err
192+ }
193+ }
194+
195+ return nil
196+ }
197+
198+ func getUniqueClusterHosts (shards []Shard ) []string {
199+ hostSet := make (map [string ]bool )
200+ for _ , shard := range shards {
201+ hosts := shard .Host
202+ if slashIdx := strings .Index (hosts , "/" ); slashIdx != - 1 {
203+ hosts = hosts [slashIdx + 1 :]
204+ }
205+
206+ for _ , host := range strings .Split (hosts , "," ) {
207+ host = strings .TrimSpace (host )
208+ if host != "" {
209+ hostSet [host ] = true
210+ }
211+ }
212+ }
213+
214+ hosts := make ([]string , 0 , len (hostSet ))
215+ for host := range hostSet {
216+ hosts = append (hosts , host )
217+ }
218+
219+ return hosts
220+ }
0 commit comments