@@ -2,10 +2,15 @@ package mongo
22
33import (
44 "context"
5+ "crypto/tls"
56 "errors"
67 "fmt"
8+ "strings"
9+ "sync"
710
11+ "github.com/PeerDB-io/peerdb/flow/shared"
812 "go.mongodb.org/mongo-driver/v2/mongo"
13+ "go.mongodb.org/mongo-driver/v2/mongo/options"
914)
1015
1116const (
@@ -16,7 +21,15 @@ const (
1621 ShardedCluster = "ShardedCluster"
1722)
1823
19- func ValidateServerCompatibility (ctx context.Context , client * mongo.Client ) error {
24+ type Credentials struct {
25+ RootCa * string
26+ Username string
27+ Password string
28+ TlsHost string
29+ DisableTls bool
30+ }
31+
32+ func ValidateServerCompatibility (ctx context.Context , client * mongo.Client , credentials Credentials ) error {
2033 buildInfo , err := GetBuildInfo (ctx , client )
2134 if err != nil {
2235 return err
@@ -48,8 +61,7 @@ func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) erro
4861 if topologyType == ReplicaSet {
4962 return validateStorageEngine (ctx , client )
5063 } else {
51- // TODO: run validation on shard
52- return nil
64+ return runOnShardsInParallel (ctx , client , credentials , validateStorageEngine )
5365 }
5466}
5567
@@ -79,7 +91,7 @@ func ValidateUserRoles(ctx context.Context, client *mongo.Client) error {
7991 return nil
8092}
8193
82- func ValidateOplogRetention (ctx context.Context , client * mongo.Client ) error {
94+ func ValidateOplogRetention (ctx context.Context , client * mongo.Client , credentials Credentials ) error {
8395 validateOplogRetention := func (instanceCtx context.Context , instanceClient * mongo.Client ) error {
8496 ss , err := GetServerStatus (instanceCtx , instanceClient )
8597 if err != nil {
@@ -100,8 +112,7 @@ func ValidateOplogRetention(ctx context.Context, client *mongo.Client) error {
100112 if topology == ReplicaSet {
101113 return validateOplogRetention (ctx , client )
102114 } else {
103- // TODO: run validation on shard
104- return nil
115+ return runOnShardsInParallel (ctx , client , credentials , validateOplogRetention )
105116 }
106117}
107118
@@ -124,3 +135,91 @@ func GetTopologyType(ctx context.Context, client *mongo.Client) (string, error)
124135 }
125136 return "" , errors .New ("topology type must be ReplicaSet or ShardedCluster" )
126137}
138+
139+ func runOnShardsInParallel (
140+ ctx context.Context ,
141+ client * mongo.Client ,
142+ credentials Credentials ,
143+ runCommand func (ctx context.Context , client * mongo.Client ) error ,
144+ ) error {
145+ res , err := GetListShards (ctx , client )
146+ if err != nil {
147+ return err
148+ }
149+
150+ if res .Ok != 1 || len (res .Shards ) == 0 {
151+ return errors .New ("invalid shards" )
152+ }
153+
154+ hosts := getUniqueClusterHosts (res .Shards )
155+ hostsErrors := make ([]error , len (hosts ))
156+ var wg sync.WaitGroup
157+ for idx , host := range hosts {
158+ wg .Add (1 )
159+ go func (i int , h string ) {
160+ defer wg .Done ()
161+
162+ shardOpts := options .Client ().
163+ ApplyURI ("mongodb://" + h ).
164+ SetDirect (true ).
165+ SetAuth (options.Credential {
166+ Username : credentials .Username ,
167+ Password : credentials .Password ,
168+ })
169+
170+ if ! credentials .DisableTls {
171+ tlsConfig , err := shared .CreateTlsConfig (tls .VersionTLS12 , credentials .RootCa , "" , credentials .TlsHost , false )
172+ if err != nil {
173+ hostsErrors [i ] = fmt .Errorf ("host %s TLS config error: %w" , h , err )
174+ return
175+ }
176+ shardOpts .SetTLSConfig (tlsConfig )
177+ }
178+
179+ shardClient , err := mongo .Connect (shardOpts )
180+ if err != nil {
181+ hostsErrors [i ] = fmt .Errorf ("host %s connect error: %w" , h , err )
182+ return
183+ }
184+ defer shardClient .Disconnect (ctx ) //nolint:errcheck
185+
186+ if err := runCommand (ctx , shardClient ); err != nil {
187+ hostsErrors [i ] = fmt .Errorf ("host %s command error: %w" , h , err )
188+ return
189+ }
190+ }(idx , host )
191+ }
192+ wg .Wait ()
193+
194+ for _ , err = range hostsErrors {
195+ if err != nil {
196+ return err
197+ }
198+ }
199+
200+ return nil
201+ }
202+
203+ func getUniqueClusterHosts (shards []Shard ) []string {
204+ hostSet := make (map [string ]bool )
205+ for _ , shard := range shards {
206+ hosts := shard .Host
207+ if slashIdx := strings .Index (hosts , "/" ); slashIdx != - 1 {
208+ hosts = hosts [slashIdx + 1 :]
209+ }
210+
211+ for _ , host := range strings .Split (hosts , "," ) {
212+ host = strings .TrimSpace (host )
213+ if host != "" {
214+ hostSet [host ] = true
215+ }
216+ }
217+ }
218+
219+ hosts := make ([]string , 0 , len (hostSet ))
220+ for host := range hostSet {
221+ hosts = append (hosts , host )
222+ }
223+
224+ return hosts
225+ }
0 commit comments