Skip to content

Commit 117f461

Browse files
committed
update test case
1 parent a23e594 commit 117f461

File tree

11 files changed

+316
-196
lines changed

11 files changed

+316
-196
lines changed

internal/credproviders/aws_provider.go

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,15 @@ import (
1010
"context"
1111

1212
"go.mongodb.org/mongo-driver/v2/internal/aws/credentials"
13+
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
1314
)
1415

1516
const awsProviderName = "AwsProvider"
1617

17-
// AwsCredentialsProvider is a function that retrieves AWS credentials.
18-
type AwsCredentialsProvider func(context.Context) (AwsCredentials, error)
19-
20-
// AwsCredentials represents AWS credentials with an expiration callback.
21-
type AwsCredentials struct {
22-
AccessKeyID string
23-
SecretAccessKey string
24-
SessionToken string
25-
ExpirationCallback func() bool
26-
}
27-
2818
// AwsProvider retrieves credentials from the given AWS credentials provider.
2919
type AwsProvider struct {
30-
credentials *AwsCredentials
31-
32-
Provider AwsCredentialsProvider
20+
credentials *driver.Credentials
21+
Provider func(context.Context) (driver.Credentials, error)
3322
}
3423

3524
// Retrieve retrieves the keys from the given AWS credentials provider.

internal/integration/client_side_encryption_prose_test.go

Lines changed: 107 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import (
2828
"go.mongodb.org/mongo-driver/v2/bson"
2929
"go.mongodb.org/mongo-driver/v2/event"
3030
"go.mongodb.org/mongo-driver/v2/internal/assert"
31-
"go.mongodb.org/mongo-driver/v2/internal/credproviders"
3231
"go.mongodb.org/mongo-driver/v2/internal/handshake"
3332
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
3433
"go.mongodb.org/mongo-driver/v2/internal/integtest"
@@ -3146,144 +3145,125 @@ func TestClientSideEncryptionProse(t *testing.T) {
31463145
})
31473146
}
31483147
})
3148+
}
31493149

3150-
mt.RunOpts("26. custom AWS credentials", qeRunOpts22, func(mt *mtest.T) {
3151-
mt.Run("Case 1: ClientEncryption with credentialProviders and incorrect kmsProviders", func(mt *mtest.T) {
3152-
opts := options.Client().ApplyURI(mtest.ClusterURI())
3153-
integtest.AddTestServerAPIVersion(opts)
3154-
keyVaultClient, err := mongo.Connect(opts)
3155-
assert.NoErrorf(mt, err, "error on Connect: %v", err)
3150+
func TestCustomAwsCredentialsProse(t *testing.T) {
3151+
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
31563152

3157-
ceo := options.ClientEncryption().
3158-
SetKeyVaultNamespace("keyvault.datakeys").
3159-
SetKmsProviders(map[string]map[string]any{
3160-
"aws": {
3161-
"accessKeyId": awsAccessKeyID,
3162-
"secretAccessKey": awsSecretAccessKey,
3163-
},
3164-
}).
3165-
SetCredentialProviders(map[string]options.CredentialsProvider{
3166-
"aws": func(ctx context.Context) (options.Credentials, error) {
3167-
var cred options.Credentials
3168-
provider := credproviders.NewEnvProvider()
3169-
c, err := provider.Retrieve(ctx)
3170-
if err != nil {
3171-
return cred, err
3172-
}
3173-
cred.AccessKeyID = c.AccessKeyID
3174-
cred.SecretAccessKey = c.SecretAccessKey
3175-
cred.SessionToken = c.SessionToken
3176-
cred.ExpirationCallback = provider.IsExpired
3177-
return cred, nil
3178-
},
3179-
})
3180-
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3181-
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
3153+
mt.Run("Case 1: ClientEncryption with credentialProviders and incorrect kmsProviders", func(mt *mtest.T) {
3154+
opts := options.Client().ApplyURI(mtest.ClusterURI())
3155+
integtest.AddTestServerAPIVersion(opts)
3156+
keyVaultClient, err := mongo.Connect(opts)
3157+
assert.NoErrorf(mt, err, "error on Connect: %v", err)
31823158

3183-
dkOpts := options.DataKey()
3184-
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
3185-
assert.Error(mt, err, "expected an error")
3186-
})
3187-
mt.Run("Case 2: ClientEncryption with credentialProviders works", func(mt *mtest.T) {
3188-
opts := options.Client().ApplyURI(mtest.ClusterURI())
3189-
integtest.AddTestServerAPIVersion(opts)
3190-
keyVaultClient, err := mongo.Connect(opts)
3191-
assert.NoErrorf(mt, err, "error on Connect: %v", err)
3159+
ceo := options.ClientEncryption().
3160+
SetKeyVaultNamespace("keyvault.datakeys").
3161+
SetKmsProviders(map[string]map[string]any{
3162+
"aws": {
3163+
"accessKeyId": awsAccessKeyID,
3164+
"secretAccessKey": awsSecretAccessKey,
3165+
},
3166+
}).
3167+
SetCredentialProviders(map[string]options.CredentialsProvider{
3168+
"aws": func(ctx context.Context) (options.Credentials, error) {
3169+
return options.Credentials{}, nil
3170+
},
3171+
})
3172+
_, err = mongo.NewClientEncryption(keyVaultClient, ceo)
3173+
assert.ErrorContains(mt, err, "can only provide a custom AWS credential provider",
3174+
"unexpected error: %v", err)
3175+
})
31923176

3193-
var calledCount int
3194-
ceo := options.ClientEncryption().
3195-
SetKeyVaultNamespace("keyvault.datakeys").
3196-
SetKmsProviders(map[string]map[string]any{
3197-
"aws": map[string]any{},
3198-
}).
3199-
SetCredentialProviders(map[string]options.CredentialsProvider{
3200-
"aws": func(_ context.Context) (options.Credentials, error) {
3201-
calledCount++
3202-
return options.Credentials{
3203-
AccessKeyID: awsAccessKeyID,
3204-
SecretAccessKey: awsSecretAccessKey,
3205-
ExpirationCallback: func() bool { return false },
3206-
}, nil
3207-
},
3208-
})
3209-
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3210-
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
3177+
mt.Run("Case 2: ClientEncryption with credentialProviders works", func(mt *mtest.T) {
3178+
opts := options.Client().ApplyURI(mtest.ClusterURI())
3179+
integtest.AddTestServerAPIVersion(opts)
3180+
keyVaultClient, err := mongo.Connect(opts)
3181+
assert.NoErrorf(mt, err, "error on Connect: %v", err)
32113182

3212-
dkOpts := options.DataKey()
3213-
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
3214-
assert.NoErrorf(mt, err, "unexpected error %v", err)
3215-
assert.Equal(mt, 1, calledCount, "expected credential provider to be called once")
3216-
})
3183+
var calledCount int
3184+
ceo := options.ClientEncryption().
3185+
SetKeyVaultNamespace("keyvault.datakeys").
3186+
SetKmsProviders(map[string]map[string]any{
3187+
"aws": map[string]any{},
3188+
}).
3189+
SetCredentialProviders(map[string]options.CredentialsProvider{
3190+
"aws": func(_ context.Context) (options.Credentials, error) {
3191+
calledCount++
3192+
return options.Credentials{
3193+
AccessKeyID: awsAccessKeyID,
3194+
SecretAccessKey: awsSecretAccessKey,
3195+
ExpirationCallback: func() bool { return false },
3196+
}, nil
3197+
},
3198+
})
3199+
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3200+
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
32173201

3218-
mt.Run("Case 3: AutoEncryptionOpts with credentialProviders and incorrect kmsProviders", func(mt *mtest.T) {
3219-
aeo := options.AutoEncryption().
3220-
SetKeyVaultNamespace("keyvault.datakeys").
3221-
SetKmsProviders(map[string]map[string]any{
3222-
"aws": {
3223-
"accessKeyId": awsAccessKeyID,
3224-
"secretAccessKey": awsSecretAccessKey,
3225-
},
3226-
}).
3227-
SetCredentialProviders(map[string]options.CredentialsProvider{
3228-
"aws": func(ctx context.Context) (options.Credentials, error) {
3229-
var cred options.Credentials
3230-
provider := credproviders.NewEnvProvider()
3231-
c, err := provider.Retrieve(ctx)
3232-
if err != nil {
3233-
return cred, err
3234-
}
3235-
cred.AccessKeyID = c.AccessKeyID
3236-
cred.SecretAccessKey = c.SecretAccessKey
3237-
cred.SessionToken = c.SessionToken
3238-
cred.ExpirationCallback = provider.IsExpired
3239-
return cred, nil
3240-
},
3241-
})
3242-
co := options.Client().SetAutoEncryptionOptions(aeo).ApplyURI(mtest.ClusterURI())
3243-
integtest.AddTestServerAPIVersion(co)
3244-
_, err := mongo.Connect(co)
3245-
assert.Error(mt, err, "expected an error")
3202+
dkOpts := options.DataKey().SetMasterKey(bson.D{
3203+
{"region", "us-east-1"},
3204+
{"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"},
32463205
})
3206+
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
3207+
assert.NoErrorf(mt, err, "unexpected error %v", err)
3208+
assert.Equal(mt, 1, calledCount, "expected credential provider to be called once")
3209+
})
32473210

3248-
mt.Run("Case 4: ClientEncryption with credentialProviders and valid environment variables", func(mt *mtest.T) {
3249-
mt.Setenv("AWS_ACCESS_KEY_ID", os.Getenv("FLE_AWS_SECRET_ACCESS_KEY"))
3250-
mt.Setenv("AWS_SECRET_ACCESS_KEY", os.Getenv("FLE_AWS_ACCESS_KEY_ID"))
3211+
mt.Run("Case 3: AutoEncryptionOpts with credentialProviders and incorrect kmsProviders", func(mt *mtest.T) {
3212+
aeo := options.AutoEncryption().
3213+
SetKeyVaultNamespace("keyvault.datakeys").
3214+
SetKmsProviders(map[string]map[string]any{
3215+
"aws": {
3216+
"accessKeyId": awsAccessKeyID,
3217+
"secretAccessKey": awsSecretAccessKey,
3218+
},
3219+
}).
3220+
SetCredentialProviders(map[string]options.CredentialsProvider{
3221+
"aws": func(ctx context.Context) (options.Credentials, error) {
3222+
return options.Credentials{}, nil
3223+
},
3224+
})
3225+
co := options.Client().SetAutoEncryptionOptions(aeo).ApplyURI(mtest.ClusterURI())
3226+
integtest.AddTestServerAPIVersion(co)
3227+
_, err := mongo.Connect(co)
3228+
assert.ErrorContainsf(mt, err, "can only provide a custom AWS credential provider",
3229+
"unexpected error: %v", err)
3230+
})
32513231

3252-
opts := options.Client().ApplyURI(mtest.ClusterURI())
3253-
integtest.AddTestServerAPIVersion(opts)
3254-
keyVaultClient, err := mongo.Connect(opts)
3255-
assert.NoErrorf(mt, err, "error on Connect: %v", err)
3232+
mt.Run("Case 4: ClientEncryption with credentialProviders and valid environment variables", func(mt *mtest.T) {
3233+
mt.Setenv("AWS_ACCESS_KEY_ID", os.Getenv("FLE_AWS_SECRET_ACCESS_KEY"))
3234+
mt.Setenv("AWS_SECRET_ACCESS_KEY", os.Getenv("FLE_AWS_ACCESS_KEY_ID"))
32563235

3257-
ceo := options.ClientEncryption().
3258-
SetKeyVaultNamespace("keyvault.datakeys").
3259-
SetKmsProviders(map[string]map[string]any{
3260-
"aws": {
3261-
"accessKeyId": awsAccessKeyID,
3262-
"secretAccessKey": awsSecretAccessKey,
3263-
},
3264-
}).
3265-
SetCredentialProviders(map[string]options.CredentialsProvider{
3266-
"aws": func(ctx context.Context) (options.Credentials, error) {
3267-
var cred options.Credentials
3268-
provider := credproviders.NewEnvProvider()
3269-
c, err := provider.Retrieve(ctx)
3270-
if err != nil {
3271-
return cred, err
3272-
}
3273-
cred.AccessKeyID = c.AccessKeyID
3274-
cred.SecretAccessKey = c.SecretAccessKey
3275-
cred.SessionToken = c.SessionToken
3276-
cred.ExpirationCallback = provider.IsExpired
3277-
return cred, nil
3278-
},
3279-
})
3280-
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3281-
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
3236+
opts := options.Client().ApplyURI(mtest.ClusterURI())
3237+
integtest.AddTestServerAPIVersion(opts)
3238+
keyVaultClient, err := mongo.Connect(opts)
3239+
assert.NoErrorf(mt, err, "error on Connect: %v", err)
32823240

3283-
dkOpts := options.DataKey()
3284-
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
3285-
assert.NoErrorf(mt, err, "unexpected error %v", err)
3241+
var calledCount int
3242+
ceo := options.ClientEncryption().
3243+
SetKeyVaultNamespace("keyvault.datakeys").
3244+
SetKmsProviders(map[string]map[string]any{
3245+
"aws": map[string]any{},
3246+
}).
3247+
SetCredentialProviders(map[string]options.CredentialsProvider{
3248+
"aws": func(ctx context.Context) (options.Credentials, error) {
3249+
calledCount++
3250+
return options.Credentials{
3251+
AccessKeyID: awsAccessKeyID,
3252+
SecretAccessKey: awsSecretAccessKey,
3253+
ExpirationCallback: func() bool { return false },
3254+
}, nil
3255+
},
3256+
})
3257+
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3258+
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
3259+
3260+
dkOpts := options.DataKey().SetMasterKey(bson.D{
3261+
{"region", "us-east-1"},
3262+
{"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"},
32863263
})
3264+
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
3265+
assert.NoErrorf(mt, err, "unexpected error %v", err)
3266+
assert.Equal(mt, 1, calledCount, "expected credential provider to be called once")
32873267
})
32883268
}
32893269

internal/test/aws/aws_test.go

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -38,37 +38,3 @@ func TestAWS(t *testing.T) {
3838
t.Logf("FindOne error: %v", err)
3939
}
4040
}
41-
42-
func TestAWSCustomCredentialProviders(t *testing.T) {
43-
uri := os.Getenv("MONGODB_URI")
44-
if uri == "" {
45-
t.Skip("Skipping test: MONGODB_URI environment variable is not set")
46-
}
47-
48-
var calledCount int
49-
awsCredential := options.Credential{
50-
AuthMechanism: "MONGODB-AWS",
51-
AwsCredentialsProvider: func(_ context.Context) (options.Credentials, error) {
52-
calledCount++
53-
return options.Credentials{
54-
AccessKeyID: os.Getenv("AWS_ACCESS_KEY_ID"),
55-
SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"),
56-
ExpirationCallback: func() bool { return false },
57-
}, nil
58-
},
59-
}
60-
client, err := mongo.Connect(options.Client().ApplyURI(uri).SetAuth(awsCredential))
61-
62-
defer func() {
63-
err = client.Disconnect(context.Background())
64-
require.NoError(t, err)
65-
}()
66-
67-
coll := client.Database("aws").Collection("test")
68-
69-
err = coll.FindOne(context.Background(), bson.D{{Key: "x", Value: 1}}).Err()
70-
if err != nil && !errors.Is(err, mongo.ErrNoDocuments) {
71-
t.Logf("FindOne error: %v", err)
72-
}
73-
require.Equalf(t, 1, calledCount, "expected custom AWS credential provider to be called once")
74-
}

mongo/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,8 +601,8 @@ func (c *Client) newMongoCrypt(opts *options.AutoEncryptionOptions) (*mongocrypt
601601
for k, fn := range opts.CredentialProviders {
602602
if k == "aws" && fn != nil {
603603
providers[k] = &credproviders.AwsProvider{
604-
Provider: func(ctx context.Context) (credproviders.AwsCredentials, error) {
605-
var creds credproviders.AwsCredentials
604+
Provider: func(ctx context.Context) (driver.Credentials, error) {
605+
var creds driver.Credentials
606606
c, err := fn(ctx)
607607
if err != nil {
608608
return creds, err

mongo/client_encryption.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ func NewClientEncryption(keyVaultClient *Client, opts ...options.Lister[options.
5959
for k, fn := range cea.CredentialProviders {
6060
if k == "aws" && fn != nil {
6161
providers[k] = &credproviders.AwsProvider{
62-
Provider: func(ctx context.Context) (credproviders.AwsCredentials, error) {
63-
var creds credproviders.AwsCredentials
62+
Provider: func(ctx context.Context) (driver.Credentials, error) {
63+
var creds driver.Credentials
6464
c, err := fn(ctx)
6565
if err != nil {
6666
return creds, err

mongo/client_examples_test.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,11 @@ func ExampleConnect_aWS() {
267267
// The order in which the driver searches for credentials is:
268268
//
269269
// 1. Credentials passed through the URI
270-
// 2. Environment variables
271-
// 3. ECS endpoint if and only if AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is
270+
// 2. Custom AWS credential provider
271+
// 3. Environment variables
272+
// 4. ECS endpoint if and only if AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is
272273
// set
273-
// 4. EC2 endpoint
274+
// 5. EC2 endpoint
274275
//
275276
// The following examples set the appropriate credentials via the
276277
// ClientOptions.SetAuth method. All of these credentials can be specified
@@ -352,6 +353,28 @@ func ExampleConnect_aWS() {
352353
panic(err)
353354
}
354355
_ = ecClient
356+
357+
// Custom AWS credential provider
358+
359+
// Applications can authenticate using a custom AWS credential provider as
360+
// well.
361+
credential := options.Credential{
362+
AuthMechanism: "MONGODB-AWS",
363+
AwsCredentialsProvider: func(_ context.Context) (options.Credentials, error) {
364+
return options.Credentials{
365+
AccessKeyID: accessKeyID,
366+
SecretAccessKey: secretAccessKey,
367+
SessionToken: sessionToken,
368+
ExpirationCallback: func() bool { return false },
369+
}, nil
370+
},
371+
}
372+
awsClient, err := mongo.Connect(
373+
options.Client().SetAuth(credential))
374+
if err != nil {
375+
panic(err)
376+
}
377+
_ = awsClient
355378
}
356379

357380
func ExampleConnect_stableAPI() {

0 commit comments

Comments
 (0)