Skip to content

Commit 79fa2e1

Browse files
RUST-1420 Cache AWS credentials received from endpoints (mongodb#905)
1 parent a018a87 commit 79fa2e1

File tree

7 files changed

+242
-16
lines changed

7 files changed

+242
-16
lines changed

.evergreen/MSRV-Cargo.lock

Lines changed: 8 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.evergreen/config.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ functions:
167167
working_dir: "src"
168168
script: |
169169
${PREPARE_SHELL}
170-
ASYNC_RUNTIME=${ASYNC_RUNTIME} .evergreen/run-aws-tests.sh
170+
ASYNC_RUNTIME=${ASYNC_RUNTIME} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh
171171
172172
"run aws auth test with assume role credentials":
173173
- command: shell.exec
@@ -205,7 +205,7 @@ functions:
205205
working_dir: "src"
206206
script: |
207207
${PREPARE_SHELL}
208-
ASYNC_RUNTIME=${ASYNC_RUNTIME} .evergreen/run-aws-tests.sh
208+
ASYNC_RUNTIME=${ASYNC_RUNTIME} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh
209209
210210
"run aws auth test with aws EC2 credentials":
211211
- command: shell.exec
@@ -245,7 +245,7 @@ functions:
245245
working_dir: "src"
246246
script: |
247247
${PREPARE_SHELL}
248-
ASYNC_RUNTIME=${ASYNC_RUNTIME} PROJECT_DIRECTORY=${PROJECT_DIRECTORY} .evergreen/run-aws-tests.sh
248+
ASYNC_RUNTIME=${ASYNC_RUNTIME} PROJECT_DIRECTORY=${PROJECT_DIRECTORY} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh
249249
250250
"run aws auth test with aws credentials and session token as environment variables":
251251
- command: shell.exec
@@ -267,7 +267,7 @@ functions:
267267
working_dir: "src"
268268
script: |
269269
${PREPARE_SHELL}
270-
ASYNC_RUNTIME=${ASYNC_RUNTIME} .evergreen/run-aws-tests.sh
270+
ASYNC_RUNTIME=${ASYNC_RUNTIME} SKIP_CREDENTIAL_CACHING_TESTS=1 .evergreen/run-aws-tests.sh
271271
272272
"run aws ECS auth test":
273273
- command: shell.exec

.evergreen/run-aws-tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,5 @@ set -o errexit
4444

4545
source ./.evergreen/configure-rust.sh
4646

47-
RUST_BACKTRACE=1 cargo test --features aws-auth auth_aws::auth_aws
47+
RUST_BACKTRACE=1 cargo test --features aws-auth auth_aws
4848
RUST_BACKTRACE=1 cargo test --features aws-auth lambda_examples::auth::test_handler

src/bson_util/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,17 @@ pub(crate) fn serialize_result_error_as_string<S: Serializer, T: Serialize>(
235235
.serialize(serializer)
236236
}
237237

238+
#[cfg(feature = "aws-auth")]
239+
pub(crate) fn deserialize_datetime_option_from_double<'de, D>(
240+
deserializer: D,
241+
) -> std::result::Result<Option<bson::DateTime>, D::Error>
242+
where
243+
D: Deserializer<'de>,
244+
{
245+
let millis = f64::deserialize(deserializer)? * 1000.0;
246+
Ok(Some(bson::DateTime::from_millis(millis as i64)))
247+
}
248+
238249
#[cfg(test)]
239250
mod test {
240251
use crate::bson_util::num_decimal_digits;

src/client/auth/aws.rs

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
use std::{fs::File, io::Read};
1+
use std::{fs::File, io::Read, time::Duration};
22

33
use chrono::{offset::Utc, DateTime};
44
use hmac::Hmac;
5+
use lazy_static::lazy_static;
56
use rand::distributions::{Alphanumeric, DistString};
67
use serde::Deserialize;
78
use sha2::{Digest, Sha256};
9+
use tokio::sync::Mutex;
810

911
use crate::{
1012
bson::{doc, rawdoc, spec::BinarySubtype, Binary, Bson, Document},
13+
bson_util::deserialize_datetime_option_from_double,
1114
client::{
1215
auth::{
1316
self,
@@ -27,12 +30,31 @@ const AWS_EC2_IP: &str = "169.254.169.254";
2730
const AWS_LONG_DATE_FMT: &str = "%Y%m%dT%H%M%SZ";
2831
const MECH_NAME: &str = "MONGODB-AWS";
2932

33+
lazy_static! {
34+
static ref CACHED_CREDENTIAL: Mutex<Option<AwsCredential>> = Mutex::new(None);
35+
}
36+
3037
/// Performs MONGODB-AWS authentication for a given stream.
3138
pub(super) async fn authenticate_stream(
3239
conn: &mut Connection,
3340
credential: &Credential,
3441
server_api: Option<&ServerApi>,
3542
http_client: &HttpClient,
43+
) -> Result<()> {
44+
match authenticate_stream_inner(conn, credential, server_api, http_client).await {
45+
Ok(()) => Ok(()),
46+
Err(error) => {
47+
*CACHED_CREDENTIAL.lock().await = None;
48+
Err(error)
49+
}
50+
}
51+
}
52+
53+
async fn authenticate_stream_inner(
54+
conn: &mut Connection,
55+
credential: &Credential,
56+
server_api: Option<&ServerApi>,
57+
http_client: &HttpClient,
3658
) -> Result<()> {
3759
let source = match credential.source.as_deref() {
3860
Some("$external") | None => "$external",
@@ -68,7 +90,23 @@ pub(super) async fn authenticate_stream(
6890
let server_first = ServerFirst::parse(server_first_response.auth_response_body(MECH_NAME)?)?;
6991
server_first.validate(&nonce)?;
7092

71-
let aws_credential = AwsCredential::get(credential, http_client).await?;
93+
let aws_credential = {
94+
// Limit scope of this variable to avoid holding onto the lock for the duration of
95+
// authenticate_stream.
96+
let cached_credential = CACHED_CREDENTIAL.lock().await;
97+
match *cached_credential {
98+
Some(ref aws_credential) if !aws_credential.is_expired() => aws_credential.clone(),
99+
_ => {
100+
// From the spec: the driver MUST not place a lock on making a request.
101+
drop(cached_credential);
102+
let aws_credential = AwsCredential::get(credential, http_client).await?;
103+
if aws_credential.expiration.is_some() {
104+
*CACHED_CREDENTIAL.lock().await = Some(aws_credential.clone());
105+
}
106+
aws_credential
107+
}
108+
}
109+
};
72110

73111
let date = Utc::now();
74112

@@ -117,7 +155,7 @@ pub(super) async fn authenticate_stream(
117155
}
118156

119157
/// Contains the credentials for MONGODB-AWS authentication.
120-
#[derive(Debug, Deserialize)]
158+
#[derive(Clone, Debug, Deserialize)]
121159
#[serde(rename_all = "PascalCase")]
122160
pub(crate) struct AwsCredential {
123161
access_key_id: String,
@@ -126,6 +164,9 @@ pub(crate) struct AwsCredential {
126164

127165
#[serde(alias = "Token")]
128166
session_token: Option<String>,
167+
168+
#[serde(default, deserialize_with = "deserialize_datetime_option_from_double")]
169+
expiration: Option<bson::DateTime>,
129170
}
130171

131172
impl AwsCredential {
@@ -157,6 +198,7 @@ impl AwsCredential {
157198
access_key_id: access_key,
158199
secret_access_key: secret_key,
159200
session_token,
201+
expiration: None,
160202
});
161203
}
162204

@@ -419,6 +461,16 @@ impl AwsCredential {
419461
pub(crate) fn session_token(&self) -> Option<&str> {
420462
self.session_token.as_deref()
421463
}
464+
465+
fn is_expired(&self) -> bool {
466+
match self.expiration {
467+
Some(expiration) => {
468+
expiration.saturating_duration_since(bson::DateTime::now())
469+
< Duration::from_secs(5 * 60)
470+
}
471+
None => true,
472+
}
473+
}
422474
}
423475

424476
/// The response from the server to the `saslStart` command in a MONGODB-AWS authentication attempt.
@@ -496,3 +548,45 @@ impl ServerFirst {
496548
}
497549
}
498550
}
551+
552+
#[cfg(test)]
553+
pub(crate) mod test_utils {
554+
use super::{AwsCredential, CACHED_CREDENTIAL};
555+
556+
pub(crate) async fn cached_credential() -> Option<AwsCredential> {
557+
CACHED_CREDENTIAL.lock().await.clone()
558+
}
559+
560+
pub(crate) async fn clear_cached_credential() {
561+
*CACHED_CREDENTIAL.lock().await = None;
562+
}
563+
564+
pub(crate) async fn poison_cached_credential() {
565+
CACHED_CREDENTIAL
566+
.lock()
567+
.await
568+
.as_mut()
569+
.unwrap()
570+
.access_key_id = "bad".into();
571+
}
572+
573+
pub(crate) async fn cached_access_key_id() -> String {
574+
cached_credential().await.unwrap().access_key_id
575+
}
576+
577+
pub(crate) async fn cached_secret_access_key() -> String {
578+
cached_credential().await.unwrap().secret_access_key
579+
}
580+
581+
pub(crate) async fn cached_session_token() -> Option<String> {
582+
cached_credential().await.unwrap().session_token
583+
}
584+
585+
pub(crate) async fn cached_expiration() -> bson::DateTime {
586+
cached_credential().await.unwrap().expiration.unwrap()
587+
}
588+
589+
pub(crate) async fn set_cached_expiration(expiration: bson::DateTime) {
590+
CACHED_CREDENTIAL.lock().await.as_mut().unwrap().expiration = Some(expiration);
591+
}
592+
}

src/test/auth_aws.rs

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
use bson::Document;
2-
use tokio::sync::RwLockReadGuard;
1+
use std::env::{remove_var, set_var, var};
2+
3+
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
4+
5+
use crate::{bson::Document, client::auth::aws::test_utils::*, test::DEFAULT_URI, Client};
36

47
use super::{TestClient, LOCK};
58

@@ -13,3 +16,118 @@ async fn auth_aws() {
1316

1417
coll.find_one(None, None).await.unwrap();
1518
}
19+
20+
// The TestClient performs operations upon creation that trigger authentication, so the credential
21+
// caching tests use a regular client instead to avoid that noise.
22+
async fn get_client() -> Client {
23+
Client::with_uri_str(DEFAULT_URI.clone()).await.unwrap()
24+
}
25+
26+
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
27+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
28+
async fn credential_caching() {
29+
// This test should only be run when authenticating using AWS endpoints.
30+
if var("SKIP_CREDENTIAL_CACHING_TESTS").is_ok() {
31+
return;
32+
}
33+
34+
let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await;
35+
36+
clear_cached_credential().await;
37+
38+
let client = get_client().await;
39+
let coll = client.database("aws").collection::<Document>("somecoll");
40+
coll.find_one(None, None).await.unwrap();
41+
assert!(cached_credential().await.is_some());
42+
43+
let now = bson::DateTime::now();
44+
set_cached_expiration(now).await;
45+
46+
let client = get_client().await;
47+
let coll = client.database("aws").collection::<Document>("somecoll");
48+
coll.find_one(None, None).await.unwrap();
49+
assert!(cached_credential().await.is_some());
50+
assert!(cached_expiration().await > now);
51+
52+
poison_cached_credential().await;
53+
54+
let client = get_client().await;
55+
let coll = client.database("aws").collection::<Document>("somecoll");
56+
match coll.find_one(None, None).await {
57+
Ok(_) => panic!(
58+
"find one should have failed with authentication error due to poisoned cached \
59+
credential"
60+
),
61+
Err(error) => assert!(error.is_auth_error()),
62+
}
63+
assert!(cached_credential().await.is_none());
64+
65+
coll.find_one(None, None).await.unwrap();
66+
assert!(cached_credential().await.is_some());
67+
}
68+
69+
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
70+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
71+
async fn credential_caching_environment_vars() {
72+
// This test should only be run when authenticating using AWS endpoints.
73+
if var("SKIP_CREDENTIAL_CACHING_TESTS").is_ok() {
74+
return;
75+
}
76+
77+
let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await;
78+
79+
clear_cached_credential().await;
80+
81+
let client = get_client().await;
82+
let coll = client.database("aws").collection::<Document>("somecoll");
83+
coll.find_one(None, None).await.unwrap();
84+
assert!(cached_credential().await.is_some());
85+
86+
set_var("AWS_ACCESS_KEY_ID", cached_access_key_id().await);
87+
set_var("AWS_SECRET_ACCESS_KEY", cached_secret_access_key().await);
88+
if let Some(session_token) = cached_session_token().await {
89+
set_var("AWS_SESSION_TOKEN", session_token);
90+
}
91+
clear_cached_credential().await;
92+
93+
let client = get_client().await;
94+
let coll = client.database("aws").collection::<Document>("somecoll");
95+
coll.find_one(None, None).await.unwrap();
96+
assert!(cached_credential().await.is_none());
97+
98+
set_var("AWS_ACCESS_KEY_ID", "bad");
99+
set_var("AWS_SECRET_ACCESS_KEY", "bad");
100+
set_var("AWS_SESSION_TOKEN", "bad");
101+
102+
let client = get_client().await;
103+
let coll = client.database("aws").collection::<Document>("somecoll");
104+
match coll.find_one(None, None).await {
105+
Ok(_) => panic!(
106+
"find one should have failed with authentication error due to poisoned environment \
107+
variables"
108+
),
109+
Err(error) => assert!(error.is_auth_error()),
110+
}
111+
112+
remove_var("AWS_ACCESS_KEY_ID");
113+
remove_var("AWS_SECRET_ACCESS_KEY");
114+
remove_var("AWS_SESSION_TOKEN");
115+
clear_cached_credential().await;
116+
117+
let client = get_client().await;
118+
let coll = client.database("aws").collection::<Document>("somecoll");
119+
coll.find_one(None, None).await.unwrap();
120+
assert!(cached_credential().await.is_some());
121+
122+
set_var("AWS_ACCESS_KEY_ID", "bad");
123+
set_var("AWS_SECRET_ACCESS_KEY", "bad");
124+
set_var("AWS_SESSION_TOKEN", "bad");
125+
126+
let client = get_client().await;
127+
let coll = client.database("aws").collection::<Document>("somecoll");
128+
coll.find_one(None, None).await.unwrap();
129+
130+
remove_var("AWS_ACCESS_KEY_ID");
131+
remove_var("AWS_SECRET_ACCESS_KEY");
132+
remove_var("AWS_SESSION_TOKEN");
133+
}

src/test/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#[cfg(all(not(feature = "sync"), not(feature = "tokio-sync")))]
22
mod atlas_connectivity;
33
mod atlas_planned_maintenance_testing;
4+
#[cfg(feature = "aws-auth")]
45
mod auth_aws;
56
mod change_stream;
67
mod client;

0 commit comments

Comments
 (0)