Skip to content

Commit a1b8b46

Browse files
committed
updates
1 parent 3ddf46a commit a1b8b46

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

mongo/integration/mtest/mongotest.go

+25-7
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,10 @@ type T struct {
128128
succeeded []*event.CommandSucceededEvent
129129
failed []*event.CommandFailedEvent
130130

131-
Client *mongo.Client
132-
DB *mongo.Database
133-
Coll *mongo.Collection
131+
Client *mongo.Client
132+
fpClient *mongo.Client
133+
DB *mongo.Database
134+
Coll *mongo.Collection
134135
}
135136

136137
func newT(wrapped *testing.T, opts ...*Options) *T {
@@ -405,7 +406,9 @@ func (t *T) ResetClient(opts *options.ClientOptions) {
405406
t.clientOpts = opts
406407
}
407408

408-
_ = t.Client.Disconnect(context.Background())
409+
if len(t.failPointNames) == 0 {
410+
_ = t.Client.Disconnect(context.Background())
411+
}
409412
t.createTestClient()
410413
t.DB = t.Client.Database(t.dbName)
411414
t.Coll = t.DB.Collection(t.collName, t.collOpts)
@@ -559,7 +562,10 @@ func (t *T) SetFailPoint(fp FailPoint) {
559562
}
560563
}
561564

562-
if err := SetFailPoint(fp, t.Client); err != nil {
565+
if t.fpClient == nil {
566+
t.fpClient = t.Client
567+
}
568+
if err := SetFailPoint(fp, t.fpClient); err != nil {
563569
t.Fatal(err)
564570
}
565571
t.failPointNames = append(t.failPointNames, fp.ConfigureFailPoint)
@@ -570,7 +576,10 @@ func (t *T) SetFailPoint(fp FailPoint) {
570576
// the failpoint will appear in command monitoring channels. The fail point will be automatically disabled after this
571577
// test has run.
572578
func (t *T) SetFailPointFromDocument(fp bson.Raw) {
573-
if err := SetRawFailPoint(fp, t.Client); err != nil {
579+
if t.fpClient == nil {
580+
t.fpClient = t.Client
581+
}
582+
if err := SetRawFailPoint(fp, t.fpClient); err != nil {
574583
t.Fatal(err)
575584
}
576585

@@ -586,7 +595,16 @@ func (t *T) TrackFailPoint(fpName string) {
586595

587596
// ClearFailPoints disables all previously set failpoints for this test.
588597
func (t *T) ClearFailPoints() {
589-
db := t.Client.Database("admin")
598+
client := t.fpClient
599+
if client == nil {
600+
client = t.Client
601+
} else {
602+
defer func() {
603+
// _ = t.fpClient.Disconnect(context.Background())
604+
t.fpClient = nil
605+
}()
606+
}
607+
db := client.Database("admin")
590608
for _, fp := range t.failPointNames {
591609
cmd := bson.D{
592610
{"configureFailPoint", fp},

0 commit comments

Comments
 (0)