From 0a630dc8102d2d2eea7293c3c8fbbb7d765df639 Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Mon, 4 Nov 2024 21:57:57 +0800 Subject: [PATCH] refactor: Migrate aws-v4 to new design (#502) --- core/src/api.rs | 3 +- core/src/context.rs | 16 + core/src/env.rs | 2 +- core/src/fs.rs | 2 +- core/src/http.rs | 2 +- core/src/lib.rs | 1 + core/src/request.rs | 1 + core/src/signer.rs | 6 +- core/src/utils.rs | 73 ++ services/aws-v4/Cargo.toml | 1 + services/aws-v4/benches/aws.rs | 23 +- services/aws-v4/src/{signer.rs => build.rs} | 353 +++--- services/aws-v4/src/config.rs | 27 +- services/aws-v4/src/credential.rs | 1024 ----------------- services/aws-v4/src/key.rs | 47 + services/aws-v4/src/lib.rs | 22 +- services/aws-v4/src/load/assume_role.rs | 175 +++ .../src/load/assume_role_with_web_identity.rs | 137 +++ services/aws-v4/src/load/config.rs | 36 + services/aws-v4/src/load/default.rs | 240 ++++ services/aws-v4/src/load/imds.rs | 153 +++ services/aws-v4/src/load/mod.rs | 16 + services/aws-v4/src/load/utils.rs | 34 + services/aws-v4/tests/main.rs | 278 ++++- 24 files changed, 1409 insertions(+), 1263 deletions(-) create mode 100644 core/src/utils.rs rename services/aws-v4/src/{signer.rs => build.rs} (74%) delete mode 100644 services/aws-v4/src/credential.rs create mode 100644 services/aws-v4/src/key.rs create mode 100644 services/aws-v4/src/load/assume_role.rs create mode 100644 services/aws-v4/src/load/assume_role_with_web_identity.rs create mode 100644 services/aws-v4/src/load/config.rs create mode 100644 services/aws-v4/src/load/default.rs create mode 100644 services/aws-v4/src/load/imds.rs create mode 100644 services/aws-v4/src/load/mod.rs create mode 100644 services/aws-v4/src/load/utils.rs diff --git a/core/src/api.rs b/core/src/api.rs index 863910f..5b84151 100644 --- a/core/src/api.rs +++ b/core/src/api.rs @@ -1,4 +1,3 @@ -use super::SigningRequest; use crate::Context; use std::fmt::Debug; use std::time::Duration; @@ -61,5 +60,5 @@ pub trait Build: Debug + Send + Sync + Unpin + 'static { req: &mut http::request::Parts, key: Option<&Self::Key>, expires_in: Option, - ) -> anyhow::Result; + ) -> anyhow::Result<()>; } diff --git a/core/src/context.rs b/core/src/context.rs index 2f49a62..cc950a4 100644 --- a/core/src/context.rs +++ b/core/src/context.rs @@ -38,12 +38,28 @@ impl Context { self.fs.file_read(path).await } + /// Read the file content entirely in `String`. + pub async fn file_read_as_string(&self, path: &str) -> Result { + let bytes = self.file_read(path).await?; + Ok(String::from_utf8_lossy(&bytes).to_string()) + } + /// Send http request and return the response. #[inline] pub async fn http_send(&self, req: http::Request) -> Result> { self.http.http_send(req).await } + /// Send http request and return the response as string. + pub async fn http_send_as_string( + &self, + req: http::Request, + ) -> Result> { + let (parts, body) = self.http.http_send(req).await?.into_parts(); + let body = String::from_utf8_lossy(&body).to_string(); + Ok(http::Response::from_parts(parts, body)) + } + /// Get the home directory of the current user. #[inline] pub fn home_dir(&self) -> Option { diff --git a/core/src/env.rs b/core/src/env.rs index fa0cc61..1f7bf9b 100644 --- a/core/src/env.rs +++ b/core/src/env.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use std::path::PathBuf; /// Permits parameterizing the home functions via the _from variants -pub trait Env: Debug + 'static { +pub trait Env: Debug + Send + Sync + 'static { /// Get an environment variable. /// /// - Returns `Some(v)` if the environment variable is found and is valid utf-8. diff --git a/core/src/fs.rs b/core/src/fs.rs index 1af14f1..4d3a6e1 100644 --- a/core/src/fs.rs +++ b/core/src/fs.rs @@ -5,7 +5,7 @@ use std::fmt::Debug; /// /// This could be used by `Load` to load the credential from the file. #[async_trait::async_trait] -pub trait FileRead: Debug + 'static { +pub trait FileRead: Debug + Send + Sync + 'static { /// Read the file content entirely in `Vec`. async fn file_read(&self, path: &str) -> Result>; } diff --git a/core/src/http.rs b/core/src/http.rs index da51383..d8fd0a0 100644 --- a/core/src/http.rs +++ b/core/src/http.rs @@ -7,7 +7,7 @@ use std::fmt::Debug; /// For example, fetch IMDS token from AWS or OAuth2 refresh token. This trait is designed /// especially for the signer, please don't use it as a general http client. #[async_trait::async_trait] -pub trait HttpSend: Debug + 'static { +pub trait HttpSend: Debug + Send + Sync + 'static { /// Send http request and return the response. async fn http_send(&self, req: http::Request) -> Result>; } diff --git a/core/src/lib.rs b/core/src/lib.rs index c41c9ac..b19f19f 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -5,6 +5,7 @@ pub mod hash; pub mod time; +pub mod utils; mod context; pub use context::Context; diff --git a/core/src/request.rs b/core/src/request.rs index f4ea92f..bfe750a 100644 --- a/core/src/request.rs +++ b/core/src/request.rs @@ -15,6 +15,7 @@ use http::Uri; use std::str::FromStr; /// Signing context for request. +#[derive(Debug)] pub struct SigningRequest { /// HTTP method. pub method: Method, diff --git a/core/src/signer.rs b/core/src/signer.rs index b1ebae3..d776c51 100644 --- a/core/src/signer.rs +++ b/core/src/signer.rs @@ -39,10 +39,8 @@ impl Signer { ctx }; - let signing = self - .builder + self.builder .build(&self.ctx, req, key.as_ref(), expires_in) - .await?; - signing.apply(req) + .await } } diff --git a/core/src/utils.rs b/core/src/utils.rs new file mode 100644 index 0000000..0a70402 --- /dev/null +++ b/core/src/utils.rs @@ -0,0 +1,73 @@ +//! Utility functions and types. + +use std::fmt::Debug; + +/// Redacts a string by replacing all but the first and last three characters with asterisks. +/// +/// - If the input string has fewer than 12 characters, it should be entirely redacted. +/// - If the input string has 12 or more characters, only the first three and the last three. +/// +/// This design is to allow users to distinguish between different redacted strings but avoid +/// leaking sensitive information. +pub struct Redact<'a>(&'a str); + +impl<'a> From<&'a str> for Redact<'a> { + fn from(value: &'a str) -> Self { + Redact(value) + } +} + +impl<'a> From<&'a String> for Redact<'a> { + fn from(value: &'a String) -> Self { + Redact(value.as_str()) + } +} + +impl<'a> From<&'a Option> for Redact<'a> { + fn from(value: &'a Option) -> Self { + match value { + None => Redact(""), + Some(v) => Redact(v), + } + } +} + +impl<'a> Debug for Redact<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let length = self.0.len(); + if length == 0 { + f.write_str("EMPTY") + } else if length < 12 { + f.write_str("***") + } else { + f.write_str(&self.0[..3])?; + f.write_str("***")?; + f.write_str(&self.0[length - 3..]) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_redact() { + let cases = vec![ + ("Short", "***"), + ("Hello World!", "Hel***ld!"), + ("This is a longer string", "Thi***ing"), + ("", "EMPTY"), + ("HelloWorld", "***"), + ]; + + for (input, expected) in cases { + assert_eq!( + format!("{:?}", Redact(input)), + expected, + "Failed on input: {}", + input + ); + } + } +} diff --git a/services/aws-v4/Cargo.toml b/services/aws-v4/Cargo.toml index 1457ce2..1a4514b 100644 --- a/services/aws-v4/Cargo.toml +++ b/services/aws-v4/Cargo.toml @@ -27,6 +27,7 @@ reqwest.workspace = true rust-ini.workspace = true serde.workspace = true serde_json.workspace = true +bytes = "1.7.2" [dev-dependencies] aws-credential-types = "1.1.8" diff --git a/services/aws-v4/benches/aws.rs b/services/aws-v4/benches/aws.rs index 19cfbac..f7932dd 100644 --- a/services/aws-v4/benches/aws.rs +++ b/services/aws-v4/benches/aws.rs @@ -9,12 +9,24 @@ use aws_sigv4::sign::v4::SigningParams; use criterion::criterion_group; use criterion::criterion_main; use criterion::Criterion; +use once_cell::sync::Lazy; +use reqsign_aws_v4::Builder as AwsV4Builder; use reqsign_aws_v4::Credential as AwsCredential; -use reqsign_aws_v4::Signer as AwsV4Signer; +use reqsign_core::{Build, Context}; +use reqsign_file_read_tokio::TokioFileRead; +use reqsign_http_send_reqwest::ReqwestHttpSend; criterion_group!(benches, bench); criterion_main!(benches); +static RUNTIME: Lazy = Lazy::new(|| { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .expect("must success") +}); + pub fn bench(c: &mut Criterion) { let mut group = c.benchmark_group("aws_v4"); @@ -25,9 +37,10 @@ pub fn bench(c: &mut Criterion) { ..Default::default() }; - let s = AwsV4Signer::new("s3", "test"); + let s = AwsV4Builder::new("s3", "test"); + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); - b.iter(|| { + b.to_async(&*RUNTIME).iter(|| async { let mut req = http::Request::new(""); *req.method_mut() = http::Method::GET; *req.uri_mut() = "http://127.0.0.1:9000/hello" @@ -35,7 +48,9 @@ pub fn bench(c: &mut Criterion) { .expect("url must be valid"); let (mut parts, _) = req.into_parts(); - s.sign(&mut parts, &cred).expect("must success") + s.build(&ctx, &mut parts, Some(&cred), None) + .await + .expect("must success") }) }); diff --git a/services/aws-v4/src/signer.rs b/services/aws-v4/src/build.rs similarity index 74% rename from services/aws-v4/src/signer.rs rename to services/aws-v4/src/build.rs index bf500e0..313c01e 100644 --- a/services/aws-v4/src/signer.rs +++ b/services/aws-v4/src/build.rs @@ -1,48 +1,36 @@ -//! AWS service sigv4 signer - -use std::fmt::Debug; +use crate::constants::{ + AWS_QUERY_ENCODE_SET, X_AMZ_CONTENT_SHA_256, X_AMZ_DATE, X_AMZ_SECURITY_TOKEN, +}; +use crate::Credential; +use async_trait::async_trait; +use http::request::Parts; +use http::{header, HeaderValue}; +use log::debug; +use percent_encoding::{percent_decode_str, utf8_percent_encode}; +use reqsign_core::hash::{hex_hmac_sha256, hex_sha256, hmac_sha256}; +use reqsign_core::time::{format_date, format_iso8601, now, DateTime}; +use reqsign_core::{Build, Context, SigningRequest}; use std::fmt::Write; use std::time::Duration; -use anyhow::Result; -use http::header; -use http::HeaderValue; -use log::debug; -use percent_encoding::percent_decode_str; -use percent_encoding::utf8_percent_encode; - -use super::constants::AWS_QUERY_ENCODE_SET; -use super::constants::X_AMZ_CONTENT_SHA_256; -use super::constants::X_AMZ_DATE; -use super::constants::X_AMZ_SECURITY_TOKEN; -use super::credential::Credential; -use reqsign_core::hash::hex_hmac_sha256; -use reqsign_core::hash::hex_sha256; -use reqsign_core::hash::hmac_sha256; -use reqsign_core::time::format_date; -use reqsign_core::time::format_iso8601; -use reqsign_core::time::now; -use reqsign_core::time::DateTime; -use reqsign_core::SigningMethod; -use reqsign_core::SigningRequest; - -/// Signer that implement AWS SigV4. +/// Builder that implement AWS SigV4. /// /// - [Signature Version 4 signing process](https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html) #[derive(Debug)] -pub struct Signer { +pub struct Builder { service: String, region: String, time: Option, } -impl Signer { - /// Create a builder. +impl Builder { + /// Create a new builder for AWS V4 signer. pub fn new(service: &str, region: &str) -> Self { Self { - service: service.to_string(), - region: region.to_string(), + service: service.into(), + region: region.into(), + time: None, } } @@ -54,26 +42,43 @@ impl Signer { /// We should always take current time to sign requests. /// Only use this function for testing. #[cfg(test)] - pub fn time(mut self, time: DateTime) -> Self { + pub fn with_time(mut self, time: DateTime) -> Self { self.time = Some(time); self } +} + +#[async_trait] +impl Build for Builder { + type Key = Credential; - fn build( + async fn build( &self, - req: &mut http::request::Parts, - method: SigningMethod, - cred: &Credential, - ) -> Result { + _: &Context, + req: &mut Parts, + key: Option<&Self::Key>, + expires_in: Option, + ) -> anyhow::Result<()> { let now = self.time.unwrap_or_else(now); - let mut ctx = SigningRequest::build(req)?; + let mut signed_req = SigningRequest::build(req)?; + + let Some(cred) = key else { + return Ok(()); + }; // canonicalize context - canonicalize_header(&mut ctx, method, cred, now)?; - canonicalize_query(&mut ctx, method, cred, now, &self.service, &self.region)?; + canonicalize_header(&mut signed_req, cred, expires_in, now)?; + canonicalize_query( + &mut signed_req, + cred, + expires_in, + now, + &self.service, + &self.region, + )?; // build canonical request and string to sign. - let creq = canonical_request_string(&mut ctx)?; + let creq = canonical_request_string(&mut signed_req)?; let encoded_req = hex_sha256(creq.as_bytes()); // Scope: "20220313///aws4_request" @@ -105,52 +110,29 @@ impl Signer { generate_signing_key(&cred.secret_access_key, now, &self.region, &self.service); let signature = hex_hmac_sha256(&signing_key, string_to_sign.as_bytes()); - match method { - SigningMethod::Header => { - let mut authorization = HeaderValue::from_str(&format!( - "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", - cred.access_key_id, - scope, - ctx.header_name_to_vec_sorted().join(";"), - signature - ))?; - authorization.set_sensitive(true); - - ctx.headers - .insert(http::header::AUTHORIZATION, authorization); - } - SigningMethod::Query(_) => { - ctx.query.push(("X-Amz-Signature".into(), signature)); - } + if expires_in.is_some() { + signed_req.query.push(("X-Amz-Signature".into(), signature)); + } else { + let mut authorization = HeaderValue::from_str(&format!( + "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", + cred.access_key_id, + scope, + signed_req.header_name_to_vec_sorted().join(";"), + signature + ))?; + authorization.set_sensitive(true); + + signed_req + .headers + .insert(header::AUTHORIZATION, authorization); } - Ok(ctx) - } - - /// Get the region of this signer. - pub fn region(&self) -> &str { - &self.region - } - - /// Signing request with header. - pub fn sign(&self, parts: &mut http::request::Parts, cred: &Credential) -> Result<()> { - let ctx = self.build(parts, SigningMethod::Header, cred)?; - ctx.apply(parts) - } - - /// Signing request with query. - pub fn sign_query( - &self, - parts: &mut http::request::Parts, - expire: Duration, - cred: &Credential, - ) -> Result<()> { - let ctx = self.build(parts, SigningMethod::Query(expire), cred)?; - ctx.apply(parts) + // Apply to the request. + signed_req.apply(req) } } -fn canonical_request_string(ctx: &mut SigningRequest) -> Result { +fn canonical_request_string(ctx: &mut SigningRequest) -> anyhow::Result { // 256 is specially chosen to avoid reallocation for most requests. let mut f = String::with_capacity(256); @@ -198,10 +180,10 @@ fn canonical_request_string(ctx: &mut SigningRequest) -> Result { fn canonicalize_header( ctx: &mut SigningRequest, - method: SigningMethod, cred: &Credential, + expires_in: Option, now: DateTime, -) -> Result<()> { +) -> anyhow::Result<()> { // Header names and values need to be normalized according to Step 4 of https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html for (_, value) in ctx.headers.iter_mut() { SigningRequest::header_value_normalize(value) @@ -213,7 +195,7 @@ fn canonicalize_header( .insert(header::HOST, ctx.authority.as_str().parse()?); } - if method == SigningMethod::Header { + if expires_in.is_none() { // Insert DATE header if not present. if ctx.headers.get(X_AMZ_DATE).is_none() { let date_header = HeaderValue::try_from(format_iso8601(now))?; @@ -243,13 +225,13 @@ fn canonicalize_header( fn canonicalize_query( ctx: &mut SigningRequest, - method: SigningMethod, cred: &Credential, + expires_in: Option, now: DateTime, service: &str, region: &str, -) -> Result<()> { - if let SigningMethod::Query(expire) = method { +) -> anyhow::Result<()> { + if let Some(expire) = expires_in { ctx.query .push(("X-Amz-Algorithm".into(), "AWS4-HMAC-SHA256".into())); ctx.query.push(( @@ -317,6 +299,9 @@ fn generate_signing_key(secret: &str, time: DateTime, region: &str, service: &st mod tests { use std::time::SystemTime; + use super::*; + use crate::Config; + use crate::DefaultLoader; use anyhow::Result; use aws_credential_types::Credentials; use aws_sigv4::http_request::PayloadChecksumKind; @@ -328,15 +313,34 @@ mod tests { use aws_sigv4::sign::v4; use http::header; use http::Request; - use macro_rules_attribute::apply; - use reqwest::Client; - - use super::*; - use crate::Config; - use crate::DefaultLoader; + use reqsign_core::Load; + use reqsign_file_read_tokio::TokioFileRead; + use reqsign_http_send_reqwest::ReqwestHttpSend; + + /// (name, request_builder) + type TestCase = (&'static str, fn() -> Request<&'static str>); + + fn test_cases() -> Vec { + vec![ + ("get_request", test_get_request), + ("get_request_with_sse", test_get_request_with_sse), + ("get_request_with_query", test_get_request_with_query), + ("get_request_virtual_host", test_get_request_virtual_host), + ( + "get_request_with_query_virtual_host", + test_get_request_with_query_virtual_host, + ), + ("put_request", test_put_request), + ( + "put_request_with_body_digest", + test_put_request_with_body_digest, + ), + ("put_request_virtual_host", test_put_request_virtual_host), + ] + } - fn test_get_request() -> http::Request<&'static str> { - let mut req = http::Request::new(""); + fn test_get_request() -> Request<&'static str> { + let mut req = Request::new(""); *req.method_mut() = http::Method::GET; *req.uri_mut() = "http://127.0.0.1:9000/hello" .parse() @@ -345,8 +349,8 @@ mod tests { req } - fn test_get_request_with_sse() -> http::Request<&'static str> { - let mut req = http::Request::new(""); + fn test_get_request_with_sse() -> Request<&'static str> { + let mut req = Request::new(""); *req.method_mut() = http::Method::GET; *req.uri_mut() = "http://127.0.0.1:9000/hello" .parse() @@ -375,8 +379,8 @@ mod tests { req } - fn test_get_request_with_query() -> http::Request<&'static str> { - let mut req = http::Request::new(""); + fn test_get_request_with_query() -> Request<&'static str> { + let mut req = Request::new(""); *req.method_mut() = http::Method::GET; *req.uri_mut() = "http://127.0.0.1:9000/hello?list-type=2&max-keys=3&prefix=CI/&start-after=ExampleGuide.pdf" .parse() @@ -385,8 +389,8 @@ mod tests { req } - fn test_get_request_virtual_host() -> http::Request<&'static str> { - let mut req = http::Request::new(""); + fn test_get_request_virtual_host() -> Request<&'static str> { + let mut req = Request::new(""); *req.method_mut() = http::Method::GET; *req.uri_mut() = "http://hello.s3.test.example.com" .parse() @@ -395,8 +399,8 @@ mod tests { req } - fn test_get_request_with_query_virtual_host() -> http::Request<&'static str> { - let mut req = http::Request::new(""); + fn test_get_request_with_query_virtual_host() -> Request<&'static str> { + let mut req = Request::new(""); *req.method_mut() = http::Method::GET; *req.uri_mut() = "http://hello.s3.test.example.com?list-type=2&max-keys=3&prefix=CI/&start-after=ExampleGuide.pdf" .parse() @@ -405,9 +409,9 @@ mod tests { req } - fn test_put_request() -> http::Request<&'static str> { + fn test_put_request() -> Request<&'static str> { let content = "Hello,World!"; - let mut req = http::Request::new(content); + let mut req = Request::new(content); *req.method_mut() = http::Method::PUT; *req.uri_mut() = "http://127.0.0.1:9000/hello" .parse() @@ -421,9 +425,9 @@ mod tests { req } - fn test_put_request_with_body_digest() -> http::Request<&'static str> { + fn test_put_request_with_body_digest() -> Request<&'static str> { let content = "Hello,World!"; - let mut req = http::Request::new(content); + let mut req = Request::new(content); *req.method_mut() = http::Method::PUT; *req.uri_mut() = "http://127.0.0.1:9000/hello" .parse() @@ -443,9 +447,9 @@ mod tests { req } - fn test_put_request_virtual_host() -> http::Request<&'static str> { + fn test_put_request_virtual_host() -> Request<&'static str> { let content = "Hello,World!"; - let mut req = http::Request::new(content); + let mut req = Request::new(content); *req.method_mut() = http::Method::PUT; *req.uri_mut() = "http://hello.s3.test.example.com" .parse() @@ -459,22 +463,9 @@ mod tests { req } - macro_rules! test_cases { - ($($tt:tt)*) => { - #[test_case::test_case(test_get_request)] - #[test_case::test_case(test_get_request_with_sse)] - #[test_case::test_case(test_get_request_with_query)] - #[test_case::test_case(test_get_request_virtual_host)] - #[test_case::test_case(test_get_request_with_query_virtual_host)] - #[test_case::test_case(test_put_request)] - #[test_case::test_case(test_put_request_virtual_host)] - #[test_case::test_case(test_put_request_with_body_digest)] - $($tt)* - }; - } - - fn compare_request(name: &str, l: &http::Request<&str>, r: &http::Request<&str>) { - fn format_headers(req: &http::Request<&str>) -> Vec { + #[track_caller] + fn compare_request(name: &str, l: &Request<&str>, r: &Request<&str>) { + fn format_headers(req: &Request<&str>) -> Vec { let mut hs = req .headers() .iter() @@ -496,7 +487,7 @@ mod tests { "{name} header mismatch" ); - fn format_query(req: &http::Request<&str>) -> Vec { + fn format_query(req: &Request<&str>) -> Vec { let query = req.uri().query().unwrap_or_default(); let mut query = form_urlencoded::parse(query.as_bytes()) .map(|(k, v)| format!("{}={}", &k, &v)) @@ -508,9 +499,28 @@ mod tests { assert_eq!(format_query(l), format_query(r), "{name} query mismatch"); } - #[apply(test_cases)] #[tokio::test] - async fn test_calculate(req_fn: fn() -> http::Request<&'static str>) -> Result<()> { + async fn test() -> Result<()> { + for (name, req) in test_cases() { + calculate(req) + .await + .unwrap_or_else(|err| panic!("calculate {name} should pass: {err:?}")); + calculate_in_query(req) + .await + .unwrap_or_else(|err| panic!("calculate_in_query {name} should pass: {err:?}")); + test_calculate_with_token(req).await.unwrap_or_else(|err| { + panic!("test_calculate_with_token {name} should pass: {err:?}") + }); + test_calculate_with_token_in_query(req) + .await + .unwrap_or_else(|err| { + panic!("test_calculate_with_token_in_query {name} should pass: {err:?}") + }); + } + Ok(()) + } + + async fn calculate(req_fn: fn() -> Request<&'static str>) -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); let mut req = req_fn(); @@ -558,8 +568,7 @@ mod tests { ) .unwrap(), &sp.into(), - ) - .expect("signing must succeed"); + )?; let (aws_sig, _) = output.into_parts(); aws_sig.apply_to_request_http1x(&mut req); let expected_req = req; @@ -567,29 +576,31 @@ mod tests { let req = req_fn(); let (mut parts, body) = req.into_parts(); + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); let loader = DefaultLoader::new( - Client::new(), Config { access_key_id: Some("access_key_id".to_string()), secret_access_key: Some("secret_access_key".to_string()), ..Default::default() - }, + } + .into(), ); - let cred = loader.load().await?.unwrap(); + let cred = loader.load(&ctx).await?.unwrap(); - let signer = Signer::new("s3", "test").time(now); - signer.sign(&mut parts, &cred).expect("must apply success"); + let builder = Builder::new("s3", "test").with_time(now); + builder + .build(&ctx, &mut parts, Some(&cred), None) + .await + .expect("must apply success"); - let actual_req = http::request::Request::from_parts(parts, body); + let actual_req = Request::from_parts(parts, body); compare_request(&name, &expected_req, &actual_req); Ok(()) } - #[apply(test_cases)] - #[tokio::test] - async fn test_calculate_in_query(req_fn: fn() -> http::Request<&'static str>) -> Result<()> { + async fn calculate_in_query(req_fn: fn() -> Request<&'static str>) -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); let mut req = req_fn(); @@ -648,19 +659,27 @@ mod tests { let req = req_fn(); let (mut parts, body) = req.into_parts(); + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); let loader = DefaultLoader::new( - Client::new(), Config { access_key_id: Some("access_key_id".to_string()), secret_access_key: Some("secret_access_key".to_string()), ..Default::default() - }, + } + .into(), ); - let cred = loader.load().await?.unwrap(); + let cred = loader.load(&ctx).await?.unwrap(); - let signer = Signer::new("s3", "test").time(now); + let builder = Builder::new("s3", "test").with_time(now); - signer.sign_query(&mut parts, Duration::from_secs(3600), &cred)?; + builder + .build( + &ctx, + &mut parts, + Some(&cred), + Some(Duration::from_secs(3600)), + ) + .await?; let actual_req = Request::from_parts(parts, body); compare_request(&name, &expected_req, &actual_req); @@ -668,9 +687,7 @@ mod tests { Ok(()) } - #[apply(test_cases)] - #[tokio::test] - async fn test_calculate_with_token(req_fn: fn() -> http::Request<&'static str>) -> Result<()> { + async fn test_calculate_with_token(req_fn: fn() -> Request<&'static str>) -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); let mut req = req_fn(); @@ -727,20 +744,23 @@ mod tests { let req = req_fn(); let (mut parts, body) = req.into_parts(); + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); let loader = DefaultLoader::new( - Client::new(), Config { access_key_id: Some("access_key_id".to_string()), secret_access_key: Some("secret_access_key".to_string()), session_token: Some("security_token".to_string()), ..Default::default() - }, + } + .into(), ); - let cred = loader.load().await?.unwrap(); - - let signer = Signer::new("s3", "test").time(now); + let cred = loader.load(&ctx).await?.unwrap(); - signer.sign(&mut parts, &cred).expect("must apply success"); + let builder = Builder::new("s3", "test").with_time(now); + builder + .build(&ctx, &mut parts, Some(&cred), None) + .await + .expect("must apply success"); let actual_req = Request::from_parts(parts, body); compare_request(&name, &expected_req, &actual_req); @@ -748,10 +768,8 @@ mod tests { Ok(()) } - #[apply(test_cases)] - #[tokio::test] async fn test_calculate_with_token_in_query( - req_fn: fn() -> http::Request<&'static str>, + req_fn: fn() -> Request<&'static str>, ) -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); @@ -812,20 +830,27 @@ mod tests { let req = req_fn(); let (mut parts, body) = req.into_parts(); + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); let loader = DefaultLoader::new( - Client::new(), Config { access_key_id: Some("access_key_id".to_string()), secret_access_key: Some("secret_access_key".to_string()), session_token: Some("security_token".to_string()), ..Default::default() - }, + } + .into(), ); - let cred = loader.load().await?.unwrap(); - - let signer = Signer::new("s3", "test").time(now); - signer - .sign_query(&mut parts, Duration::from_secs(3600), &cred) + let cred = loader.load(&ctx).await?.unwrap(); + + let builder = Builder::new("s3", "test").with_time(now); + builder + .build( + &ctx, + &mut parts, + Some(&cred), + Some(Duration::from_secs(3600)), + ) + .await .expect("must apply success"); let actual_req = Request::from_parts(parts, body); diff --git a/services/aws-v4/src/config.rs b/services/aws-v4/src/config.rs index 9d6d2ca..031e3ad 100644 --- a/services/aws-v4/src/config.rs +++ b/services/aws-v4/src/config.rs @@ -1,3 +1,5 @@ +use std::fmt; + use super::constants::*; #[cfg(not(target_arch = "wasm32"))] use anyhow::anyhow; @@ -7,11 +9,11 @@ use anyhow::Result; use ini::Ini; #[cfg(not(target_arch = "wasm32"))] use log::debug; +use reqsign_core::utils::Redact; use reqsign_core::Context; /// Config for aws services. #[derive(Clone)] -#[cfg_attr(test, derive(Debug))] pub struct Config { /// `config_file` will be load from: /// @@ -128,6 +130,29 @@ impl Default for Config { } } +impl fmt::Debug for Config { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Config") + .field("config_file", &self.config_file) + .field("shared_credentials_file", &self.shared_credentials_file) + .field("profile", &self.profile) + .field("region", &self.region) + .field("sts_regional_endpoints", &self.sts_regional_endpoints) + .field("access_key_id", &Redact::from(&self.access_key_id)) + .field("secret_access_key", &Redact::from(&self.secret_access_key)) + .field("session_token", &Redact::from(&self.session_token)) + .field("role_arn", &self.role_arn) + .field("role_session_name", &self.role_session_name) + .field("duration_seconds", &self.duration_seconds) + .field("external_id", &Redact::from(&self.external_id)) + .field("tags", &self.tags) + .field("web_identity_token_file", &self.web_identity_token_file) + .field("ec2_metadata_disabled", &self.ec2_metadata_disabled) + .field("endpoint_url", &self.endpoint_url) + .finish() + } +} + impl Config { /// Load config from env. pub fn from_env(mut self, ctx: &Context) -> Self { diff --git a/services/aws-v4/src/credential.rs b/services/aws-v4/src/credential.rs deleted file mode 100644 index 5485711..0000000 --- a/services/aws-v4/src/credential.rs +++ /dev/null @@ -1,1024 +0,0 @@ -use std::fmt::Debug; -use std::fmt::Write; -use std::fs; -use std::sync::Arc; -use std::sync::Mutex; - -use anyhow::anyhow; -use anyhow::Result; -use async_trait::async_trait; -use http::header::CONTENT_LENGTH; -use log::debug; -use quick_xml::de; -use reqwest::Client; -use serde::Deserialize; - -use super::config::Config; -use super::constants::X_AMZ_CONTENT_SHA_256; -use crate::Signer; -use reqsign_core::time::now; -use reqsign_core::time::parse_rfc3339; -use reqsign_core::time::DateTime; - -pub const EMPTY_STRING_SHA256: &str = - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; - -/// Credential that holds the access_key and secret_key. -#[derive(Default, Clone)] -#[cfg_attr(test, derive(Debug))] -pub struct Credential { - /// Access key id for aws services. - pub access_key_id: String, - /// Secret access key for aws services. - pub secret_access_key: String, - /// Session token for aws services. - pub session_token: Option, - /// Expiration time for this credential. - pub expires_in: Option, -} - -impl Credential { - /// is current cred is valid? - pub fn is_valid(&self) -> bool { - if (self.access_key_id.is_empty() || self.secret_access_key.is_empty()) - && self.session_token.is_none() - { - return false; - } - // Take 120s as buffer to avoid edge cases. - if let Some(valid) = self - .expires_in - .map(|v| v > now() + chrono::TimeDelta::try_minutes(2).expect("in bounds")) - { - return valid; - } - - true - } -} - -/// Loader trait will try to load credential from different sources. -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -pub trait CredentialLoad: 'static + Send + Sync { - /// Load credential from sources. - /// - /// - If succeed, return `Ok(Some(cred))` - /// - If not found, return `Ok(None)` - /// - If unexpected errors happened, return `Err(err)` - async fn load_credential(&self, client: Client) -> Result>; -} - -/// CredentialLoader will load credential from different methods. -pub struct DefaultLoader { - client: Client, - config: Config, - credential: Arc>>, - imds_v2_loader: Option, -} - -impl DefaultLoader { - /// Create a new CredentialLoader - pub fn new(client: Client, config: Config) -> Self { - let imds_v2_loader = if config.ec2_metadata_disabled { - None - } else { - Some(IMDSv2Loader::new(client.clone())) - }; - Self { - client, - config, - credential: Arc::default(), - imds_v2_loader, - } - } - - /// Disable load from ec2 metadata. - pub fn with_disable_ec2_metadata(mut self) -> Self { - self.imds_v2_loader = None; - self - } - - /// Load credential. - /// - /// Resolution order: - /// 1. Environment variables - /// 2. Shared config (`~/.aws/config`, `~/.aws/credentials`) - /// 3. Web Identity Tokens - /// 4. ECS (IAM Roles for Tasks) & General HTTP credentials: - /// 5. EC2 IMDSv2 - pub async fn load(&self) -> Result> { - // Return cached credential if it has been loaded at least once. - match self.credential.lock().expect("lock poisoned").clone() { - Some(cred) if cred.is_valid() => return Ok(Some(cred)), - _ => (), - } - - let cred = self.load_inner().await?; - - let mut lock = self.credential.lock().expect("lock poisoned"); - lock.clone_from(&cred); - - Ok(cred) - } - - async fn load_inner(&self) -> Result> { - if let Some(cred) = self.load_via_config().map_err(|err| { - debug!("load credential via config failed: {err:?}"); - err - })? { - return Ok(Some(cred)); - } - - if let Some(cred) = self - .load_via_assume_role_with_web_identity() - .await - .map_err(|err| { - debug!("load credential via assume_role_with_web_identity failed: {err:?}"); - err - })? - { - return Ok(Some(cred)); - } - - if let Some(cred) = self.load_via_imds_v2().await.map_err(|err| { - debug!("load credential via imds_v2 failed: {err:?}"); - err - })? { - return Ok(Some(cred)); - } - - Ok(None) - } - - fn load_via_config(&self) -> Result> { - if let (Some(ak), Some(sk)) = (&self.config.access_key_id, &self.config.secret_access_key) { - Ok(Some(Credential { - access_key_id: ak.clone(), - secret_access_key: sk.clone(), - session_token: self.config.session_token.clone(), - // Set expires_in to 10 minutes to enforce re-read - // from file. - expires_in: Some(now() + chrono::TimeDelta::try_minutes(10).expect("in bounds")), - })) - } else { - Ok(None) - } - } - - async fn load_via_imds_v2(&self) -> Result> { - let loader = match &self.imds_v2_loader { - Some(loader) => loader, - None => return Ok(None), - }; - - loader.load().await - } - - async fn load_via_assume_role_with_web_identity(&self) -> Result> { - let (token_file, role_arn) = - match (&self.config.web_identity_token_file, &self.config.role_arn) { - (Some(token_file), Some(role_arn)) => (token_file, role_arn), - _ => return Ok(None), - }; - - let token = fs::read_to_string(token_file)?; - let role_session_name = &self.config.role_session_name; - - let endpoint = self.sts_endpoint()?; - - // Construct request to AWS STS Service. - let url = format!("https://{endpoint}/?Action=AssumeRoleWithWebIdentity&RoleArn={role_arn}&WebIdentityToken={token}&Version=2011-06-15&RoleSessionName={role_session_name}"); - let req = self.client.get(&url).header( - http::header::CONTENT_TYPE.as_str(), - "application/x-www-form-urlencoded", - ); - - let resp = req.send().await?; - if resp.status() != http::StatusCode::OK { - let content = resp.text().await?; - return Err(anyhow!("request to AWS STS Services failed: {content}")); - } - - let resp: AssumeRoleWithWebIdentityResponse = de::from_str(&resp.text().await?)?; - let resp_cred = resp.result.credentials; - - let cred = Credential { - access_key_id: resp_cred.access_key_id, - secret_access_key: resp_cred.secret_access_key, - session_token: Some(resp_cred.session_token), - expires_in: Some(parse_rfc3339(&resp_cred.expiration)?), - }; - - Ok(Some(cred)) - } - - /// Get the sts endpoint. - /// - /// The returning format may look like `sts.{region}.amazonaws.com` - /// - /// # Notes - /// - /// AWS could have different sts endpoint based on it's region. - /// We can check them by region name. - /// - /// ref: https://github.com/awslabs/aws-sdk-rust/blob/31cfae2cf23be0c68a47357070dea1aee9227e3a/sdk/sts/src/aws_endpoint.rs - fn sts_endpoint(&self) -> Result { - // use regional sts if sts_regional_endpoints has been set. - if self.config.sts_regional_endpoints == "regional" { - let region = self.config.region.clone().ok_or_else(|| { - anyhow!("sts_regional_endpoints set to reginal, but region is not set") - })?; - if region.starts_with("cn-") { - Ok(format!("sts.{region}.amazonaws.com.cn")) - } else { - Ok(format!("sts.{region}.amazonaws.com")) - } - } else { - let region = self.config.region.clone().unwrap_or_default(); - if region.starts_with("cn") { - // TODO: seems aws china doesn't support global sts? - Ok("sts.amazonaws.com.cn".to_string()) - } else { - Ok("sts.amazonaws.com".to_string()) - } - } - } -} - -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -impl CredentialLoad for DefaultLoader { - async fn load_credential(&self, _: Client) -> Result> { - self.load().await - } -} - -pub struct IMDSv2Loader { - client: Client, - - token: Arc>, -} - -impl IMDSv2Loader { - /// Create a new IMDSv2Loader. - pub fn new(client: Client) -> Self { - Self { - client, - token: Arc::new(Mutex::new(("".to_string(), DateTime::MIN_UTC))), - } - } - - pub async fn load(&self) -> Result> { - let token = self.load_ec2_metadata_token().await?; - - // List all credentials that node has. - let url = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"; - let req = self - .client - .get(url) - .header("x-aws-ec2-metadata-token", &token); - let resp = req.send().await?; - if resp.status() != http::StatusCode::OK { - let content = resp.text().await?; - return Err(anyhow!( - "request to AWS EC2 Metadata Services failed: {content}" - )); - } - let profile_name = resp.text().await?; - - // Get the credentials via role_name. - let url = format!( - "http://169.254.169.254/latest/meta-data/iam/security-credentials/{profile_name}" - ); - let req = self - .client - .get(&url) - .header("x-aws-ec2-metadata-token", &token); - let resp = req.send().await?; - if resp.status() != http::StatusCode::OK { - let content = resp.text().await?; - return Err(anyhow!( - "request to AWS EC2 Metadata Services failed: {content}" - )); - } - - let content = resp.text().await?; - let resp: Ec2MetadataIamSecurityCredentials = serde_json::from_str(&content)?; - if resp.code != "Success" { - return Err(anyhow!( - "request to AWS EC2 Metadata Services failed: {content}" - )); - } - - let cred = Credential { - access_key_id: resp.access_key_id, - secret_access_key: resp.secret_access_key, - session_token: Some(resp.token), - expires_in: Some(parse_rfc3339(&resp.expiration)?), - }; - - Ok(Some(cred)) - } - - /// load_ec2_metadata_token will load ec2 metadata token from IMDS. - /// - /// Return value is (token, expires_in). - async fn load_ec2_metadata_token(&self) -> Result { - { - let (token, expires_in) = self.token.lock().expect("lock poisoned").clone(); - if expires_in > now() { - return Ok(token); - } - } - - let url = "http://169.254.169.254/latest/api/token"; - #[allow(unused_mut)] - let mut req = self - .client - .put(url) - .header(CONTENT_LENGTH, "0") - // 21600s (6h) is recommended by AWS. - .header("x-aws-ec2-metadata-token-ttl-seconds", "21600"); - - // Set timeout to 1s to avoid hanging on non-s3 env. - #[cfg(not(target_arch = "wasm32"))] - { - req = req.timeout(std::time::Duration::from_secs(1)); - } - - let resp = req.send().await?; - if resp.status() != http::StatusCode::OK { - let content = resp.text().await?; - return Err(anyhow!( - "request to AWS EC2 Metadata Services failed: {content}" - )); - } - let ec2_token = resp.text().await?; - // Set expires_in to 10 minutes to enforce re-read. - let expires_in = now() + chrono::TimeDelta::try_seconds(21600).expect("in bounds") - - chrono::TimeDelta::try_seconds(600).expect("in bounds"); - - { - *self.token.lock().expect("lock poisoned") = (ec2_token.clone(), expires_in); - } - - Ok(ec2_token) - } -} - -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -impl CredentialLoad for IMDSv2Loader { - async fn load_credential(&self, _: Client) -> Result> { - self.load().await - } -} - -/// AssumeRoleLoader will load credential via assume role. -pub struct AssumeRoleLoader { - client: Client, - config: Config, - - source_credential: Box, - sts_signer: Signer, -} - -impl AssumeRoleLoader { - /// Create a new assume role loader. - pub fn new( - client: Client, - config: Config, - source_credential: Box, - ) -> Result { - let region = config.region.clone().ok_or_else(|| { - anyhow!("assume role loader requires region, but not found, please check your configuration") - })?; - - Ok(Self { - client, - config, - source_credential, - - sts_signer: Signer::new("sts", ®ion), - }) - } - - /// Load credential via assume role. - pub async fn load(&self) -> Result> { - let role_arn =self.config.role_arn.clone().ok_or_else(|| { - anyhow!("assume role loader requires role_arn, but not found, please check your configuration") - })?; - - let role_session_name = &self.config.role_session_name; - - let endpoint = self.sts_endpoint()?; - - // Construct request to AWS STS Service. - let mut url = format!("https://{endpoint}/?Action=AssumeRole&RoleArn={role_arn}&Version=2011-06-15&RoleSessionName={role_session_name}"); - if let Some(external_id) = &self.config.external_id { - write!(url, "&ExternalId={external_id}")?; - } - if let Some(duration_seconds) = &self.config.duration_seconds { - write!(url, "&DurationSeconds={duration_seconds}")?; - } - if let Some(tags) = &self.config.tags { - for (idx, (key, value)) in tags.iter().enumerate() { - let tag_index = idx + 1; - write!( - url, - "&Tags.member.{tag_index}.Key={key}&Tags.member.{tag_index}.Value={value}" - )?; - } - } - - let req = http::request::Request::builder() - .method("GET") - .uri(url) - .header( - http::header::CONTENT_TYPE.as_str(), - "application/x-www-form-urlencoded", - ) - // Set content sha to empty string. - .header(X_AMZ_CONTENT_SHA_256, EMPTY_STRING_SHA256) - .body(reqwest::Body::from(""))?; - - let source_cred = self - .source_credential - .load_credential(self.client.clone()) - .await? - .ok_or_else(|| { - anyhow!("source credential is required for AssumeRole, but not found, please check your configuration") - })?; - - let (mut parts, body) = req.into_parts(); - self.sts_signer.sign(&mut parts, &source_cred)?; - let req = http::Request::from_parts(parts, body) - .try_into() - .map_err(|_| anyhow!("failed to convert http::Request to reqwest::Request"))?; - - let resp = self.client.execute(req).await?; - if resp.status() != http::StatusCode::OK { - let content = resp.text().await?; - return Err(anyhow!("request to AWS STS Services failed: {content}")); - } - - let resp: AssumeRoleResponse = de::from_str(&resp.text().await?)?; - let resp_cred = resp.result.credentials; - - let cred = Credential { - access_key_id: resp_cred.access_key_id, - secret_access_key: resp_cred.secret_access_key, - session_token: Some(resp_cred.session_token), - expires_in: Some(parse_rfc3339(&resp_cred.expiration)?), - }; - - Ok(Some(cred)) - } - - /// Get the sts endpoint. - /// - /// The returning format may look like `sts.{region}.amazonaws.com` - /// - /// # Notes - /// - /// AWS could have different sts endpoint based on it's region. - /// We can check them by region name. - /// - /// ref: https://github.com/awslabs/aws-sdk-rust/blob/31cfae2cf23be0c68a47357070dea1aee9227e3a/sdk/sts/src/aws_endpoint.rs - fn sts_endpoint(&self) -> Result { - // use regional sts if sts_regional_endpoints has been set. - if self.config.sts_regional_endpoints == "regional" { - let region = self.config.region.clone().ok_or_else(|| { - anyhow!("sts_regional_endpoints set to reginal, but region is not set") - })?; - if region.starts_with("cn-") { - Ok(format!("sts.{region}.amazonaws.com.cn")) - } else { - Ok(format!("sts.{region}.amazonaws.com")) - } - } else { - let region = self.config.region.clone().unwrap_or_default(); - if region.starts_with("cn") { - // TODO: seems aws china doesn't support global sts? - Ok("sts.amazonaws.com.cn".to_string()) - } else { - Ok("sts.amazonaws.com".to_string()) - } - } - } -} - -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -impl CredentialLoad for AssumeRoleLoader { - async fn load_credential(&self, _: Client) -> Result> { - self.load().await - } -} - -#[derive(Default, Debug, Deserialize)] -#[serde(default, rename_all = "PascalCase")] -struct AssumeRoleWithWebIdentityResponse { - #[serde(rename = "AssumeRoleWithWebIdentityResult")] - result: AssumeRoleWithWebIdentityResult, -} - -#[derive(Default, Debug, Deserialize)] -#[serde(default, rename_all = "PascalCase")] -struct AssumeRoleWithWebIdentityResult { - credentials: AssumeRoleWithWebIdentityCredentials, -} - -#[derive(Default, Debug, Deserialize)] -#[serde(default, rename_all = "PascalCase")] -struct AssumeRoleWithWebIdentityCredentials { - access_key_id: String, - secret_access_key: String, - session_token: String, - expiration: String, -} - -#[derive(Default, Debug, Deserialize)] -#[serde(default, rename_all = "PascalCase")] -struct AssumeRoleResponse { - #[serde(rename = "AssumeRoleResult")] - result: AssumeRoleResult, -} - -#[derive(Default, Debug, Deserialize)] -#[serde(default, rename_all = "PascalCase")] -struct AssumeRoleResult { - credentials: AssumeRoleCredentials, -} - -#[derive(Default, Debug, Deserialize)] -#[serde(default, rename_all = "PascalCase")] -struct AssumeRoleCredentials { - access_key_id: String, - secret_access_key: String, - session_token: String, - expiration: String, -} - -#[derive(Default, Debug, Deserialize)] -#[serde(default, rename_all = "PascalCase")] -struct Ec2MetadataIamSecurityCredentials { - access_key_id: String, - secret_access_key: String, - token: String, - expiration: String, - - code: String, -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - use std::env; - use std::str::FromStr; - use std::vec; - - use super::*; - use crate::constants::*; - use crate::signer::Signer; - use anyhow::Result; - use http::Request; - use http::StatusCode; - use once_cell::sync::Lazy; - use quick_xml::de; - use reqsign_core::{Context, StaticEnv}; - use reqsign_file_read_tokio::TokioFileRead; - use reqsign_http_send_reqwest::ReqwestHttpSend; - use reqwest::Client; - use tokio::runtime::Runtime; - - static RUNTIME: Lazy = Lazy::new(|| { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .expect("Should create a tokio runtime") - }); - - #[test] - fn test_credential_env_loader_without_env() { - let _ = env_logger::builder().is_test(true).try_init(); - - temp_env::with_vars_unset(vec![AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY], || { - RUNTIME.block_on(async { - let l = DefaultLoader::new(reqwest::Client::new(), Config::default()) - .with_disable_ec2_metadata(); - let x = l.load().await.expect("load must succeed"); - assert!(x.is_none()); - }) - }); - } - - #[tokio::test] - async fn test_credential_env_loader_with_env() { - let _ = env_logger::builder().is_test(true).try_init(); - - let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); - let context = context.with_env(StaticEnv { - home_dir: None, - envs: HashMap::from_iter([ - (AWS_ACCESS_KEY_ID.to_string(), "access_key_id".to_string()), - ( - AWS_SECRET_ACCESS_KEY.to_string(), - "secret_access_key".to_string(), - ), - ]), - }); - - let l = DefaultLoader::new(Client::new(), Config::default().from_env(&context)); - let x = l.load().await.expect("load must succeed"); - - let x = x.expect("must load succeed"); - assert_eq!("access_key_id", x.access_key_id); - assert_eq!("secret_access_key", x.secret_access_key); - } - - #[tokio::test] - async fn test_credential_profile_loader_from_config() { - let _ = env_logger::builder().is_test(true).try_init(); - - let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); - let context = context.with_env(StaticEnv { - home_dir: None, - envs: HashMap::from_iter([ - ( - AWS_CONFIG_FILE.to_string(), - format!( - "{}/testdata/default_config", - env::current_dir() - .expect("current_dir must exist") - .to_string_lossy() - ), - ), - ( - AWS_SHARED_CREDENTIALS_FILE.to_string(), - format!( - "{}/testdata/not_exist", - env::current_dir() - .expect("current_dir must exist") - .to_string_lossy() - ), - ), - ]), - }); - - let l = DefaultLoader::new( - Client::new(), - Config::default() - .from_env(&context) - .from_profile(&context) - .await, - ); - let x = l.load().await.unwrap().unwrap(); - assert_eq!("config_access_key_id", x.access_key_id); - assert_eq!("config_secret_access_key", x.secret_access_key); - } - - #[tokio::test] - async fn test_credential_profile_loader_from_shared() { - let _ = env_logger::builder().is_test(true).try_init(); - - let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); - let context = context.with_env(StaticEnv { - home_dir: None, - envs: HashMap::from_iter([ - ( - AWS_CONFIG_FILE.to_string(), - format!( - "{}/testdata/not_exist", - env::current_dir() - .expect("current_dir must exist") - .to_string_lossy() - ), - ), - ( - AWS_SHARED_CREDENTIALS_FILE.to_string(), - format!( - "{}/testdata/default_credential", - env::current_dir() - .expect("current_dir must exist") - .to_string_lossy() - ), - ), - ]), - }); - - let l = DefaultLoader::new( - Client::new(), - Config::default() - .from_env(&context) - .from_profile(&context) - .await, - ); - let x = l.load().await.unwrap().unwrap(); - assert_eq!("shared_access_key_id", x.access_key_id); - assert_eq!("shared_secret_access_key", x.secret_access_key); - } - - /// AWS_SHARED_CREDENTIALS_FILE should be taken first. - #[tokio::test] - async fn test_credential_profile_loader_from_both() { - let _ = env_logger::builder().is_test(true).try_init(); - - let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); - let context = context.with_env(StaticEnv { - home_dir: None, - envs: HashMap::from_iter([ - ( - AWS_CONFIG_FILE.to_string(), - format!( - "{}/testdata/default_config", - env::current_dir() - .expect("current_dir must exist") - .to_string_lossy() - ), - ), - ( - AWS_SHARED_CREDENTIALS_FILE.to_string(), - format!( - "{}/testdata/default_credential", - env::current_dir() - .expect("current_dir must exist") - .to_string_lossy() - ), - ), - ]), - }); - - let l = DefaultLoader::new( - Client::new(), - Config::default() - .from_env(&context) - .from_profile(&context) - .await, - ); - let x = l.load().await.expect("load must success").unwrap(); - assert_eq!("shared_access_key_id", x.access_key_id); - assert_eq!("shared_secret_access_key", x.secret_access_key); - } - - #[tokio::test] - async fn test_signer_with_web_loader() -> Result<()> { - let _ = env_logger::builder().is_test(true).try_init(); - - dotenv::from_filename("../../../.env").ok(); - - if env::var("REQSIGN_AWS_S3_TEST").is_err() - || env::var("REQSIGN_AWS_S3_TEST").unwrap() != "on" - { - return Ok(()); - } - - // Ignore test if role_arn not set - let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") { - v - } else { - return Ok(()); - }; - - // let provider_arn = env::var("REQSIGN_AWS_PROVIDER_ARN").expect("REQSIGN_AWS_PROVIDER_ARN not exist"); - let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist"); - - let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist"); - let file_path = format!( - "{}/testdata/web_identity_token_file", - env::current_dir() - .expect("current_dir must exist") - .to_string_lossy() - ); - fs::write(&file_path, github_token)?; - - let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); - let context = context.with_env(StaticEnv { - home_dir: None, - envs: HashMap::from_iter([ - (AWS_REGION.to_string(), region.to_string()), - (AWS_ROLE_ARN.to_string(), role_arn.to_string()), - ( - AWS_WEB_IDENTITY_TOKEN_FILE.to_string(), - file_path.to_string(), - ), - ]), - }); - - let config = Config::default().from_env(&context); - let loader = DefaultLoader::new(reqwest::Client::new(), config); - - let signer = Signer::new("s3", ®ion); - - let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region); - let mut req = Request::new(""); - *req.method_mut() = http::Method::GET; - *req.uri_mut() = - http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap(); - - let cred = loader - .load() - .await - .expect("credential must be valid") - .unwrap(); - - let (mut req, body) = req.into_parts(); - signer.sign(&mut req, &cred).expect("sign must success"); - let req = Request::from_parts(req, body); - - debug!("signed request url: {:?}", req.uri().to_string()); - debug!("signed request: {:?}", req); - - let client = Client::new(); - let resp = client.execute(req.try_into().unwrap()).await.unwrap(); - - let status = resp.status(); - debug!("got response: {:?}", resp); - debug!("got response content: {:?}", resp.text().await.unwrap()); - assert_eq!(status, StatusCode::NOT_FOUND); - Ok(()) - } - - #[tokio::test] - async fn test_signer_with_web_loader_assume_role() -> Result<()> { - let _ = env_logger::builder().is_test(true).try_init(); - - dotenv::from_filename("../../../.env").ok(); - - if env::var("REQSIGN_AWS_S3_TEST").is_err() - || env::var("REQSIGN_AWS_S3_TEST").unwrap() != "on" - { - return Ok(()); - } - - // Ignore test if role_arn not set - let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ROLE_ARN") { - v - } else { - return Ok(()); - }; - // Ignore test if assume_role_arn not set - let assume_role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") { - v - } else { - return Ok(()); - }; - - let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist"); - - let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist"); - let file_path = format!( - "{}/testdata/web_identity_token_file", - env::current_dir() - .expect("current_dir must exist") - .to_string_lossy() - ); - fs::write(&file_path, github_token)?; - - let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); - let context = context.with_env(StaticEnv { - home_dir: None, - envs: HashMap::from_iter([ - (AWS_REGION.to_string(), region.to_string()), - (AWS_ROLE_ARN.to_string(), role_arn.to_string()), - ( - AWS_WEB_IDENTITY_TOKEN_FILE.to_string(), - file_path.to_string(), - ), - ]), - }); - - let client = reqwest::Client::new(); - let default_loader = - DefaultLoader::new(client.clone(), Config::default().from_env(&context)) - .with_disable_ec2_metadata(); - - let cfg = Config { - role_arn: Some(assume_role_arn.clone()), - region: Some(region.clone()), - sts_regional_endpoints: "regional".to_string(), - ..Default::default() - }; - let loader = AssumeRoleLoader::new(client.clone(), cfg, Box::new(default_loader)) - .expect("AssumeRoleLoader must be valid"); - - let signer = Signer::new("s3", ®ion); - let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region); - let mut req = Request::new(""); - *req.method_mut() = http::Method::GET; - *req.uri_mut() = - http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap(); - let cred = loader - .load() - .await - .expect("credential must be valid") - .unwrap(); - - let (mut parts, body) = req.into_parts(); - signer.sign(&mut parts, &cred).expect("sign must success"); - let req = Request::from_parts(parts, body); - - debug!("signed request url: {:?}", req.uri().to_string()); - debug!("signed request: {:?}", req); - let client = Client::new(); - let resp = client.execute(req.try_into().unwrap()).await.unwrap(); - let status = resp.status(); - debug!("got response: {:?}", resp); - debug!("got response content: {:?}", resp.text().await.unwrap()); - assert_eq!(status, StatusCode::NOT_FOUND); - Ok(()) - } - - #[test] - fn test_parse_assume_role_with_web_identity_response() -> Result<()> { - let _ = env_logger::builder().is_test(true).try_init(); - - let content = r#" - - test_audience - - role_id:reqsign - arn:aws:sts::123:assumed-role/reqsign/reqsign - - arn:aws:iam::123:oidc-provider/example.com/ - - access_key_id - secret_access_key - session_token - 2022-05-25T11:45:17Z - - subject - - - b1663ad1-23ab-45e9-b465-9af30b202eba - -"#; - - let resp: AssumeRoleWithWebIdentityResponse = - de::from_str(content).expect("xml deserialize must success"); - - assert_eq!(&resp.result.credentials.access_key_id, "access_key_id"); - assert_eq!( - &resp.result.credentials.secret_access_key, - "secret_access_key" - ); - assert_eq!(&resp.result.credentials.session_token, "session_token"); - assert_eq!(&resp.result.credentials.expiration, "2022-05-25T11:45:17Z"); - - Ok(()) - } - - #[test] - fn test_parse_assume_role_response() -> Result<()> { - let _ = env_logger::builder().is_test(true).try_init(); - - let content = r#" - - Alice - - arn:aws:sts::123456789012:assumed-role/demo/TestAR - ARO123EXAMPLE123:TestAR - - - ASIAIOSFODNN7EXAMPLE - wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY - - AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW - LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd - QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU - 9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz - +scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA== - - 2019-11-09T13:34:41Z - - 6 - - - c6104cbe-af31-11e0-8154-cbc7ccf896c7 - -"#; - - let resp: AssumeRoleResponse = de::from_str(content).expect("xml deserialize must success"); - - assert_eq!( - &resp.result.credentials.access_key_id, - "ASIAIOSFODNN7EXAMPLE" - ); - assert_eq!( - &resp.result.credentials.secret_access_key, - "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" - ); - assert_eq!( - &resp.result.credentials.session_token, - "AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW - LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd - QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU - 9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz - +scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==" - ); - assert_eq!(&resp.result.credentials.expiration, "2019-11-09T13:34:41Z"); - - Ok(()) - } -} diff --git a/services/aws-v4/src/key.rs b/services/aws-v4/src/key.rs new file mode 100644 index 0000000..0b4800d --- /dev/null +++ b/services/aws-v4/src/key.rs @@ -0,0 +1,47 @@ +use reqsign_core::time::{now, DateTime}; +use reqsign_core::utils::Redact; +use reqsign_core::Key; +use std::fmt::{Debug, Formatter}; + +/// Credential that holds the access_key and secret_key. +#[derive(Default, Clone)] +pub struct Credential { + /// Access key id for aws services. + pub access_key_id: String, + /// Secret access key for aws services. + pub secret_access_key: String, + /// Session token for aws services. + pub session_token: Option, + /// Expiration time for this credential. + pub expires_in: Option, +} + +impl Debug for Credential { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Credential") + .field("access_key_id", &Redact::from(&self.access_key_id)) + .field("secret_access_key", &Redact::from(&self.secret_access_key)) + .field("session_token", &Redact::from(&self.session_token)) + .field("expires_in", &self.expires_in) + .finish() + } +} + +impl Key for Credential { + fn is_valid(&self) -> bool { + if (self.access_key_id.is_empty() || self.secret_access_key.is_empty()) + && self.session_token.is_none() + { + return false; + } + // Take 120s as buffer to avoid edge cases. + if let Some(valid) = self + .expires_in + .map(|v| v > now() + chrono::TimeDelta::try_minutes(2).expect("in bounds")) + { + return valid; + } + + true + } +} diff --git a/services/aws-v4/src/lib.rs b/services/aws-v4/src/lib.rs index 7bf2722..018cadd 100644 --- a/services/aws-v4/src/lib.rs +++ b/services/aws-v4/src/lib.rs @@ -1,15 +1,15 @@ //! AWS service signer +mod constants; + mod config; pub use config::Config; - -mod credential; -pub use credential::AssumeRoleLoader; -pub use credential::Credential; -pub use credential::CredentialLoad; -pub use credential::DefaultLoader; - -mod signer; -pub use signer::Signer; - -mod constants; +mod key; +pub use key::Credential; +mod build; +pub use build::Builder; +mod load; +pub use load::*; + +pub const EMPTY_STRING_SHA256: &str = + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; diff --git a/services/aws-v4/src/load/assume_role.rs b/services/aws-v4/src/load/assume_role.rs new file mode 100644 index 0000000..fa3b7b4 --- /dev/null +++ b/services/aws-v4/src/load/assume_role.rs @@ -0,0 +1,175 @@ +use crate::constants::X_AMZ_CONTENT_SHA_256; +use crate::key::Credential; +use crate::load::utils::sts_endpoint; +use crate::{Config, EMPTY_STRING_SHA256}; +use anyhow::anyhow; +use async_trait::async_trait; +use bytes::Bytes; +use quick_xml::de; +use reqsign_core::time::parse_rfc3339; +use reqsign_core::{Context, Load, Signer}; +use serde::Deserialize; +use std::fmt::Write; +use std::sync::Arc; + +/// AssumeRoleLoader will load credential via assume role. +#[derive(Debug)] +pub struct AssumeRoleLoader { + config: Arc, + + sts_signer: Signer, +} + +impl AssumeRoleLoader { + /// Create a new assume role loader. + pub fn new(config: Arc, sts_signer: Signer) -> anyhow::Result { + Ok(Self { config, sts_signer }) + } +} + +#[async_trait] +impl Load for AssumeRoleLoader { + type Key = Credential; + + async fn load(&self, ctx: &Context) -> anyhow::Result> { + let role_arn =self.config.role_arn.clone().ok_or_else(|| { + anyhow!("assume role loader requires role_arn, but not found, please check your configuration") + })?; + + let role_session_name = &self.config.role_session_name; + + let endpoint = sts_endpoint(&self.config)?; + + // Construct request to AWS STS Service. + let mut url = format!("https://{endpoint}/?Action=AssumeRole&RoleArn={role_arn}&Version=2011-06-15&RoleSessionName={role_session_name}"); + if let Some(external_id) = &self.config.external_id { + write!(url, "&ExternalId={external_id}")?; + } + if let Some(duration_seconds) = &self.config.duration_seconds { + write!(url, "&DurationSeconds={duration_seconds}")?; + } + if let Some(tags) = &self.config.tags { + for (idx, (key, value)) in tags.iter().enumerate() { + let tag_index = idx + 1; + write!( + url, + "&Tags.member.{tag_index}.Key={key}&Tags.member.{tag_index}.Value={value}" + )?; + } + } + + let req = http::request::Request::builder() + .method("GET") + .uri(url) + .header( + http::header::CONTENT_TYPE.as_str(), + "application/x-www-form-urlencoded", + ) + // Set content sha to empty string. + .header(X_AMZ_CONTENT_SHA_256, EMPTY_STRING_SHA256) + .body(Bytes::new())?; + + let (mut parts, body) = req.into_parts(); + self.sts_signer.sign(&mut parts, None).await?; + let req = http::Request::from_parts(parts, body); + + let resp = ctx.http_send_as_string(req).await?; + if resp.status() != http::StatusCode::OK { + let content = resp.into_body(); + return Err(anyhow!("request to AWS STS Services failed: {content}")); + } + + let resp: AssumeRoleResponse = de::from_str(&resp.into_body())?; + let resp_cred = resp.result.credentials; + + let cred = Credential { + access_key_id: resp_cred.access_key_id, + secret_access_key: resp_cred.secret_access_key, + session_token: Some(resp_cred.session_token), + expires_in: Some(parse_rfc3339(&resp_cred.expiration)?), + }; + + Ok(Some(cred)) + } +} + +#[derive(Default, Debug, Deserialize)] +#[serde(default, rename_all = "PascalCase")] +struct AssumeRoleResponse { + #[serde(rename = "AssumeRoleResult")] + result: AssumeRoleResult, +} + +#[derive(Default, Debug, Deserialize)] +#[serde(default, rename_all = "PascalCase")] +struct AssumeRoleResult { + credentials: AssumeRoleCredentials, +} + +#[derive(Default, Debug, Deserialize)] +#[serde(default, rename_all = "PascalCase")] +struct AssumeRoleCredentials { + access_key_id: String, + secret_access_key: String, + session_token: String, + expiration: String, +} + +#[cfg(test)] +mod tests { + use super::*; + use quick_xml::de; + + #[test] + fn test_parse_assume_role_response() -> anyhow::Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let content = r#" + + Alice + + arn:aws:sts::123456789012:assumed-role/demo/TestAR + ARO123EXAMPLE123:TestAR + + + ASIAIOSFODNN7EXAMPLE + wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY + + AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW + LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd + QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU + 9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz + +scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA== + + 2019-11-09T13:34:41Z + + 6 + + + c6104cbe-af31-11e0-8154-cbc7ccf896c7 + +"#; + + let resp: AssumeRoleResponse = de::from_str(content).expect("xml deserialize must success"); + + assert_eq!( + &resp.result.credentials.access_key_id, + "ASIAIOSFODNN7EXAMPLE" + ); + assert_eq!( + &resp.result.credentials.secret_access_key, + "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + ); + assert_eq!( + &resp.result.credentials.session_token, + "AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW + LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd + QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU + 9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz + +scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==" + ); + assert_eq!(&resp.result.credentials.expiration, "2019-11-09T13:34:41Z"); + + Ok(()) + } +} diff --git a/services/aws-v4/src/load/assume_role_with_web_identity.rs b/services/aws-v4/src/load/assume_role_with_web_identity.rs new file mode 100644 index 0000000..b143137 --- /dev/null +++ b/services/aws-v4/src/load/assume_role_with_web_identity.rs @@ -0,0 +1,137 @@ +use crate::load::utils::sts_endpoint; +use crate::{Config, Credential}; +use anyhow::anyhow; +use async_trait::async_trait; +use bytes::Bytes; +use quick_xml::de; +use reqsign_core::time::parse_rfc3339; +use reqsign_core::{Context, Load}; +use serde::Deserialize; +use std::sync::Arc; + +/// AssumeRoleLoader will load credential via assume role. +#[derive(Debug)] +pub struct AssumeRoleWithWebIdentityLoader { + config: Arc, +} + +impl AssumeRoleWithWebIdentityLoader { + /// Create a new `AssumeRoleWithWebIdentityLoader` instance. + pub fn new(cfg: Arc) -> Self { + Self { config: cfg } + } +} + +#[async_trait] +impl Load for AssumeRoleWithWebIdentityLoader { + type Key = Credential; + + async fn load(&self, ctx: &Context) -> anyhow::Result> { + let (token_file, role_arn) = + match (&self.config.web_identity_token_file, &self.config.role_arn) { + (Some(token_file), Some(role_arn)) => (token_file, role_arn), + _ => return Ok(None), + }; + + let token = ctx.file_read_as_string(token_file).await?; + let role_session_name = &self.config.role_session_name; + + let endpoint = sts_endpoint(&self.config)?; + + // Construct request to AWS STS Service. + let url = format!("https://{endpoint}/?Action=AssumeRoleWithWebIdentity&RoleArn={role_arn}&WebIdentityToken={token}&Version=2011-06-15&RoleSessionName={role_session_name}"); + let req = http::request::Request::builder() + .method("GET") + .uri(url) + .header( + http::header::CONTENT_TYPE.as_str(), + "application/x-www-form-urlencoded", + ) + .body(Bytes::new())?; + + let resp = ctx.http_send_as_string(req).await?; + if resp.status() != http::StatusCode::OK { + let content = resp.into_body(); + return Err(anyhow!("request to AWS STS Services failed: {content}")); + } + + let resp: AssumeRoleWithWebIdentityResponse = de::from_str(&resp.into_body())?; + let resp_cred = resp.result.credentials; + + let cred = Credential { + access_key_id: resp_cred.access_key_id, + secret_access_key: resp_cred.secret_access_key, + session_token: Some(resp_cred.session_token), + expires_in: Some(parse_rfc3339(&resp_cred.expiration)?), + }; + + Ok(Some(cred)) + } +} + +#[derive(Default, Debug, Deserialize)] +#[serde(default, rename_all = "PascalCase")] +struct AssumeRoleWithWebIdentityResponse { + #[serde(rename = "AssumeRoleWithWebIdentityResult")] + result: AssumeRoleWithWebIdentityResult, +} + +#[derive(Default, Debug, Deserialize)] +#[serde(default, rename_all = "PascalCase")] +struct AssumeRoleWithWebIdentityResult { + credentials: AssumeRoleWithWebIdentityCredentials, +} + +#[derive(Default, Debug, Deserialize)] +#[serde(default, rename_all = "PascalCase")] +struct AssumeRoleWithWebIdentityCredentials { + access_key_id: String, + secret_access_key: String, + session_token: String, + expiration: String, +} + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::Result; + + #[test] + fn test_parse_assume_role_with_web_identity_response() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let content = r#" + + test_audience + + role_id:reqsign + arn:aws:sts::123:assumed-role/reqsign/reqsign + + arn:aws:iam::123:oidc-provider/example.com/ + + access_key_id + secret_access_key + session_token + 2022-05-25T11:45:17Z + + subject + + + b1663ad1-23ab-45e9-b465-9af30b202eba + +"#; + + let resp: AssumeRoleWithWebIdentityResponse = + de::from_str(content).expect("xml deserialize must success"); + + assert_eq!(&resp.result.credentials.access_key_id, "access_key_id"); + assert_eq!( + &resp.result.credentials.secret_access_key, + "secret_access_key" + ); + assert_eq!(&resp.result.credentials.session_token, "session_token"); + assert_eq!(&resp.result.credentials.expiration, "2022-05-25T11:45:17Z"); + + Ok(()) + } +} diff --git a/services/aws-v4/src/load/config.rs b/services/aws-v4/src/load/config.rs new file mode 100644 index 0000000..fe5a1e2 --- /dev/null +++ b/services/aws-v4/src/load/config.rs @@ -0,0 +1,36 @@ +use crate::{Config, Credential}; +use async_trait::async_trait; +use reqsign_core::{Context, Load}; +use std::sync::Arc; + +/// TODO: we should support refresh from config file. +#[derive(Debug)] +pub struct ConfigLoader { + config: Arc, +} + +impl ConfigLoader { + /// Create a new `ConfigLoader` instance. + pub fn new(cfg: Arc) -> Self { + Self { config: cfg } + } +} + +#[async_trait] +impl Load for ConfigLoader { + type Key = Credential; + + async fn load(&self, _: &Context) -> anyhow::Result> { + let (Some(ak), Some(sk)) = (&self.config.access_key_id, &self.config.secret_access_key) + else { + return Ok(None); + }; + + Ok(Some(Credential { + access_key_id: ak.clone(), + secret_access_key: sk.clone(), + session_token: self.config.session_token.clone(), + expires_in: None, + })) + } +} diff --git a/services/aws-v4/src/load/default.rs b/services/aws-v4/src/load/default.rs new file mode 100644 index 0000000..8ae7463 --- /dev/null +++ b/services/aws-v4/src/load/default.rs @@ -0,0 +1,240 @@ +use crate::load::config::ConfigLoader; +use crate::load::{AssumeRoleWithWebIdentityLoader, IMDSv2Loader}; +use crate::{Config, Credential}; +use async_trait::async_trait; +use reqsign_core::{Context, Load}; +use std::sync::Arc; + +/// DefaultLoader is a loader that will try to load credential via default chains. +/// +/// Resolution order: +/// +/// 1. Environment variables +/// 2. Shared config (`~/.aws/config`, `~/.aws/credentials`) +/// 3. Web Identity Tokens +/// 4. ECS (IAM Roles for Tasks) & General HTTP credentials (TODO) +/// 5. EC2 IMDSv2 +#[derive(Debug)] +pub struct DefaultLoader { + config_loader: ConfigLoader, + assume_role_with_web_identity_loader: AssumeRoleWithWebIdentityLoader, + imds_v2_loader: IMDSv2Loader, +} + +impl DefaultLoader { + /// Create a new `DefaultLoader` instance. + pub fn new(config: Arc) -> Self { + let config_loader = ConfigLoader::new(config.clone()); + let assume_role_with_web_identity_loader = + AssumeRoleWithWebIdentityLoader::new(config.clone()); + let imds_v2_loader = IMDSv2Loader::new(config.clone()); + + Self { + config_loader, + assume_role_with_web_identity_loader, + imds_v2_loader, + } + } +} + +#[async_trait] +impl Load for DefaultLoader { + type Key = Credential; + + async fn load(&self, ctx: &Context) -> anyhow::Result> { + if let Some(cred) = self.config_loader.load(ctx).await? { + return Ok(Some(cred)); + } + + if let Some(cred) = self.assume_role_with_web_identity_loader.load(ctx).await? { + return Ok(Some(cred)); + } + + if let Some(cred) = self.imds_v2_loader.load(ctx).await? { + return Ok(Some(cred)); + } + + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::constants::{ + AWS_ACCESS_KEY_ID, AWS_CONFIG_FILE, AWS_SECRET_ACCESS_KEY, AWS_SHARED_CREDENTIALS_FILE, + }; + use reqsign_core::StaticEnv; + use reqsign_file_read_tokio::TokioFileRead; + use reqsign_http_send_reqwest::ReqwestHttpSend; + use std::collections::HashMap; + use std::env; + + #[tokio::test] + async fn test_credential_env_loader_without_env() { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let ctx = ctx.with_env(StaticEnv { + home_dir: None, + envs: HashMap::new(), + }); + + let cfg = Config { + ec2_metadata_disabled: true, + ..Default::default() + }; + + let l = DefaultLoader::new(Arc::new(cfg)); + let x = l.load(&ctx).await.expect("load must succeed"); + assert!(x.is_none()); + } + + #[tokio::test] + async fn test_credential_env_loader_with_env() { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let ctx = ctx.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + (AWS_ACCESS_KEY_ID.to_string(), "access_key_id".to_string()), + ( + AWS_SECRET_ACCESS_KEY.to_string(), + "secret_access_key".to_string(), + ), + ]), + }); + + let l = DefaultLoader::new(Arc::new(Config::default().from_env(&ctx))); + let x = l.load(&ctx).await.expect("load must succeed"); + + let x = x.expect("must load succeed"); + assert_eq!("access_key_id", x.access_key_id); + assert_eq!("secret_access_key", x.secret_access_key); + } + + #[tokio::test] + async fn test_credential_profile_loader_from_config() { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let ctx = ctx.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + ( + AWS_CONFIG_FILE.to_string(), + format!( + "{}/testdata/default_config", + env::current_dir() + .expect("current_dir must exist") + .to_string_lossy() + ), + ), + ( + AWS_SHARED_CREDENTIALS_FILE.to_string(), + format!( + "{}/testdata/not_exist", + env::current_dir() + .expect("current_dir must exist") + .to_string_lossy() + ), + ), + ]), + }); + + let l = DefaultLoader::new( + Config::default() + .from_env(&ctx) + .from_profile(&ctx) + .await + .into(), + ); + let x = l.load(&ctx).await.unwrap().unwrap(); + assert_eq!("config_access_key_id", x.access_key_id); + assert_eq!("config_secret_access_key", x.secret_access_key); + } + + #[tokio::test] + async fn test_credential_profile_loader_from_shared() { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let ctx = ctx.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + ( + AWS_CONFIG_FILE.to_string(), + format!( + "{}/testdata/not_exist", + env::current_dir() + .expect("current_dir must exist") + .to_string_lossy() + ), + ), + ( + AWS_SHARED_CREDENTIALS_FILE.to_string(), + format!( + "{}/testdata/default_credential", + env::current_dir() + .expect("current_dir must exist") + .to_string_lossy() + ), + ), + ]), + }); + + let l = DefaultLoader::new( + Config::default() + .from_env(&ctx) + .from_profile(&ctx) + .await + .into(), + ); + let x = l.load(&ctx).await.unwrap().unwrap(); + assert_eq!("shared_access_key_id", x.access_key_id); + assert_eq!("shared_secret_access_key", x.secret_access_key); + } + + /// AWS_SHARED_CREDENTIALS_FILE should be taken first. + #[tokio::test] + async fn test_credential_profile_loader_from_both() { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let ctx = ctx.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + ( + AWS_CONFIG_FILE.to_string(), + format!( + "{}/testdata/default_config", + env::current_dir() + .expect("current_dir must exist") + .to_string_lossy() + ), + ), + ( + AWS_SHARED_CREDENTIALS_FILE.to_string(), + format!( + "{}/testdata/default_credential", + env::current_dir() + .expect("current_dir must exist") + .to_string_lossy() + ), + ), + ]), + }); + + let l = DefaultLoader::new( + Config::default() + .from_env(&ctx) + .from_profile(&ctx) + .await + .into(), + ); + let x = l.load(&ctx).await.expect("load must success").unwrap(); + assert_eq!("shared_access_key_id", x.access_key_id); + assert_eq!("shared_secret_access_key", x.secret_access_key); + } +} diff --git a/services/aws-v4/src/load/imds.rs b/services/aws-v4/src/load/imds.rs new file mode 100644 index 0000000..c5be1fd --- /dev/null +++ b/services/aws-v4/src/load/imds.rs @@ -0,0 +1,153 @@ +use crate::{Config, Credential}; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use bytes::Bytes; +use http::header::CONTENT_LENGTH; +use http::Method; +use reqsign_core::time::{now, parse_rfc3339, DateTime}; +use reqsign_core::{Context, Load}; +use serde::Deserialize; +use std::sync::{Arc, Mutex}; + +#[derive(Debug, Clone)] +pub struct IMDSv2Loader { + config: Arc, + token: Arc>, +} + +impl IMDSv2Loader { + /// Create a new `IMDSv2Loader` instance. + pub fn new(cfg: Arc) -> Self { + Self { + config: cfg, + token: Arc::new(Mutex::new((String::new(), DateTime::default()))), + } + } +} + +impl IMDSv2Loader { + async fn load_ec2_metadata_token(&self, ctx: &Context) -> Result { + { + let (token, expires_in) = self.token.lock().expect("lock poisoned").clone(); + if expires_in > now() { + return Ok(token); + } + } + + let url = "http://169.254.169.254/latest/api/token"; + let req = http::Request::builder() + .uri(url) + .method(Method::PUT) + .header(CONTENT_LENGTH, "0") + // 21600s (6h) is recommended by AWS. + .header("x-aws-ec2-metadata-token-ttl-seconds", "21600") + .body(Bytes::new())?; + let resp = ctx.http_send_as_string(req).await?; + if resp.status() != http::StatusCode::OK { + return Err(anyhow!( + "request to AWS EC2 Metadata Services failed: {}", + resp.body() + )); + } + let ec2_token = resp.into_body(); + // Set expires_in to 10 minutes to enforce re-read. + let expires_in = now() + chrono::TimeDelta::try_seconds(21600).expect("in bounds") + - chrono::TimeDelta::try_seconds(600).expect("in bounds"); + + { + *self.token.lock().expect("lock poisoned") = (ec2_token.clone(), expires_in); + } + + Ok(ec2_token) + } +} + +#[async_trait] +impl Load for IMDSv2Loader { + type Key = Credential; + + async fn load(&self, ctx: &Context) -> Result> { + // If ec2_metadata_disabled is set, return None. + if self.config.ec2_metadata_disabled { + return Ok(None); + } + + let token = self.load_ec2_metadata_token(ctx).await?; + + // List all credentials that node has. + let url = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"; + let req = http::Request::builder() + .uri(url) + .method(Method::GET) + // 21600s (6h) is recommended by AWS. + .header("x-aws-ec2-metadata-token", &token) + .body(Bytes::new())?; + let resp = ctx.http_send_as_string(req).await?; + if resp.status() != http::StatusCode::OK { + return Err(anyhow!( + "request to AWS EC2 Metadata Services failed: {}", + resp.body() + )); + } + + let profile_name = resp.into_body(); + + // Get the credentials via role_name. + let url = format!( + "http://169.254.169.254/latest/meta-data/iam/security-credentials/{profile_name}" + ); + let req = http::Request::builder() + .uri(url) + .method(Method::GET) + // 21600s (6h) is recommended by AWS. + .header("x-aws-ec2-metadata-token", &token) + .body(Bytes::new())?; + + let resp = ctx.http_send_as_string(req).await?; + if resp.status() != http::StatusCode::OK { + return Err(anyhow!( + "request to AWS EC2 Metadata Services failed: {}", + resp.body() + )); + } + + let content = resp.into_body(); + let resp: Ec2MetadataIamSecurityCredentials = serde_json::from_str(&content)?; + if resp.code == "AssumeRoleUnauthorizedAccess" { + return Err(anyhow!( + "Incorrect IMDS/IAM configuration: [{}] {}. \ + Hint: Does this role have a trust relationship with EC2?", + resp.code, + resp.message + )); + } + if resp.code != "Success" { + return Err(anyhow!( + "Error retrieving credentials from IMDS: {} {}", + resp.code, + resp.message + )); + } + + let cred = Credential { + access_key_id: resp.access_key_id, + secret_access_key: resp.secret_access_key, + session_token: Some(resp.token), + expires_in: Some(parse_rfc3339(&resp.expiration)?), + }; + + Ok(Some(cred)) + } +} + +#[derive(Default, Debug, Deserialize)] +#[serde(default, rename_all = "PascalCase")] +struct Ec2MetadataIamSecurityCredentials { + access_key_id: String, + secret_access_key: String, + token: String, + expiration: String, + + code: String, + message: String, +} diff --git a/services/aws-v4/src/load/mod.rs b/services/aws-v4/src/load/mod.rs new file mode 100644 index 0000000..d5d6bd2 --- /dev/null +++ b/services/aws-v4/src/load/mod.rs @@ -0,0 +1,16 @@ +mod assume_role; +pub use assume_role::AssumeRoleLoader; + +mod assume_role_with_web_identity; +pub use assume_role_with_web_identity::AssumeRoleWithWebIdentityLoader; + +mod config; +pub use config::ConfigLoader; + +mod default; +pub use default::DefaultLoader; + +mod imds; +pub use imds::IMDSv2Loader; + +mod utils; diff --git a/services/aws-v4/src/load/utils.rs b/services/aws-v4/src/load/utils.rs new file mode 100644 index 0000000..1050cf7 --- /dev/null +++ b/services/aws-v4/src/load/utils.rs @@ -0,0 +1,34 @@ +use crate::Config; +use anyhow::anyhow; + +/// Get the sts endpoint. +/// +/// The returning format may look like `sts.{region}.amazonaws.com` +/// +/// # Notes +/// +/// AWS could have different sts endpoint based on it's region. +/// We can check them by region name. +/// +/// ref: https://github.com/awslabs/aws-sdk-rust/blob/31cfae2cf23be0c68a47357070dea1aee9227e3a/sdk/sts/src/aws_endpoint.rs +pub fn sts_endpoint(config: &Config) -> anyhow::Result { + // use regional sts if sts_regional_endpoints has been set. + if config.sts_regional_endpoints == "regional" { + let region = config.region.clone().ok_or_else(|| { + anyhow!("sts_regional_endpoints set to regional, but region is not set") + })?; + if region.starts_with("cn-") { + Ok(format!("sts.{region}.amazonaws.com.cn")) + } else { + Ok(format!("sts.{region}.amazonaws.com")) + } + } else { + let region = config.region.clone().unwrap_or_default(); + if region.starts_with("cn") { + // TODO: seems aws china doesn't support global sts? + Ok("sts.amazonaws.com.cn".to_string()) + } else { + Ok("sts.amazonaws.com".to_string()) + } + } +} diff --git a/services/aws-v4/tests/main.rs b/services/aws-v4/tests/main.rs index d6e3acf..cb94730 100644 --- a/services/aws-v4/tests/main.rs +++ b/services/aws-v4/tests/main.rs @@ -1,5 +1,7 @@ +use std::collections::HashMap; use std::env; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use anyhow::Result; @@ -9,17 +11,17 @@ use log::debug; use log::warn; use percent_encoding::utf8_percent_encode; use percent_encoding::NON_ALPHANUMERIC; -use reqsign_aws_v4::Config; -use reqsign_aws_v4::DefaultLoader; -use reqsign_aws_v4::Signer; -use reqsign_core::Context; +use reqsign_aws_v4::{AssumeRoleLoader, Config}; +use reqsign_aws_v4::{Builder, DefaultLoader}; +use reqsign_core::{Build, Context, Load, Signer, StaticEnv}; use reqsign_file_read_tokio::TokioFileRead; use reqsign_http_send_reqwest::ReqwestHttpSend; use reqwest::Client; use sha2::Digest; use sha2::Sha256; +use tokio::fs; -async fn init_signer() -> Option<(DefaultLoader, Signer)> { +async fn init_default_loader() -> Option<(Context, DefaultLoader, Builder)> { let _ = env_logger::builder().is_test(true).try_init(); dotenv::from_filename("../../../.env").ok(); @@ -49,24 +51,22 @@ async fn init_signer() -> Option<(DefaultLoader, Signer)> { let region = config.region.as_deref().unwrap().to_string(); - let loader = DefaultLoader::new(Client::new(), config); + let loader = DefaultLoader::new(config.into()); - let signer = Signer::new( + let builder = Builder::new( &env::var("REQSIGN_AWS_V4_SERVICE").expect("env REQSIGN_AWS_V4_SERVICE must set"), ®ion, ); - Some((loader, signer)) + Some((context, loader, builder)) } #[tokio::test] async fn test_head_object() -> Result<()> { - let signer = init_signer().await; - if signer.is_none() { + let Some((ctx, loader, builder)) = init_default_loader().await else { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); - } - let (loader, signer) = signer.unwrap(); + }; let url = &env::var("REQSIGN_AWS_V4_URL").expect("env REQSIGN_AWS_V4_URL must set"); @@ -75,15 +75,16 @@ async fn test_head_object() -> Result<()> { *req.uri_mut() = http::Uri::from_str(&format!("{}/{}", url, "not_exist_file"))?; let cred = loader - .load() + .load(&ctx) .await .expect("load request must success") .unwrap(); let req = { let (mut parts, body) = req.into_parts(); - signer - .sign(&mut parts, &cred) + builder + .build(&ctx, &mut parts, Some(&cred), None) + .await .expect("sign request must success"); Request::from_parts(parts, body) }; @@ -103,12 +104,10 @@ async fn test_head_object() -> Result<()> { #[tokio::test] async fn test_put_object_with_query() -> Result<()> { - let signer = init_signer().await; - if signer.is_none() { + let Some((ctx, loader, builder)) = init_default_loader().await else { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); - } - let (loader, signer) = signer.unwrap(); + }; let url = &env::var("REQSIGN_AWS_V4_URL").expect("env REQSIGN_AWS_V4_URL must set"); let body = "Hello, World!"; @@ -123,15 +122,16 @@ async fn test_put_object_with_query() -> Result<()> { *req.uri_mut() = http::Uri::from_str(&format!("{}/{}", url, "put_object_test"))?; let cred = loader - .load() + .load(&ctx) .await .expect("load request must success") .unwrap(); let req = { let (mut parts, body) = req.into_parts(); - signer - .sign(&mut parts, &cred) + builder + .build(&ctx, &mut parts, Some(&cred), None) + .await .expect("sign request must success"); Request::from_parts(parts, body) }; @@ -155,12 +155,10 @@ async fn test_put_object_with_query() -> Result<()> { #[tokio::test] async fn test_get_object_with_query() -> Result<()> { - let signer = init_signer().await; - if signer.is_none() { + let Some((ctx, loader, builder)) = init_default_loader().await else { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); - } - let (loader, signer) = signer.unwrap(); + }; let url = &env::var("REQSIGN_AWS_V4_URL").expect("env REQSIGN_AWS_V4_URL must set"); @@ -169,15 +167,21 @@ async fn test_get_object_with_query() -> Result<()> { *req.uri_mut() = http::Uri::from_str(&format!("{}/{}", url, "not_exist_file"))?; let cred = loader - .load() + .load(&ctx) .await .expect("load request must success") .unwrap(); let req = { let (mut parts, body) = req.into_parts(); - signer - .sign_query(&mut parts, Duration::from_secs(3600), &cred) + builder + .build( + &ctx, + &mut parts, + Some(&cred), + Some(Duration::from_secs(3600)), + ) + .await .expect("sign request must success"); Request::from_parts(parts, body) }; @@ -197,12 +201,10 @@ async fn test_get_object_with_query() -> Result<()> { #[tokio::test] async fn test_head_object_with_special_characters() -> Result<()> { - let signer = init_signer().await; - if signer.is_none() { + let Some((ctx, loader, builder)) = init_default_loader().await else { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); - } - let (loader, signer) = signer.unwrap(); + }; let url = &env::var("REQSIGN_AWS_V4_URL").expect("env REQSIGN_AWS_V4_URL must set"); @@ -215,15 +217,16 @@ async fn test_head_object_with_special_characters() -> Result<()> { ))?; let cred = loader - .load() + .load(&ctx) .await .expect("load request must success") .unwrap(); let req = { let (mut parts, body) = req.into_parts(); - signer - .sign(&mut parts, &cred) + builder + .build(&ctx, &mut parts, Some(&cred), None) + .await .expect("sign request must success"); Request::from_parts(parts, body) }; @@ -243,12 +246,10 @@ async fn test_head_object_with_special_characters() -> Result<()> { #[tokio::test] async fn test_head_object_with_encoded_characters() -> Result<()> { - let signer = init_signer().await; - if signer.is_none() { + let Some((ctx, loader, builder)) = init_default_loader().await else { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); - } - let (loader, signer) = signer.unwrap(); + }; let url = &env::var("REQSIGN_AWS_V4_URL").expect("env REQSIGN_AWS_V4_URL must set"); @@ -261,15 +262,16 @@ async fn test_head_object_with_encoded_characters() -> Result<()> { ))?; let cred = loader - .load() + .load(&ctx) .await .expect("load request must success") .unwrap(); let req = { let (mut parts, body) = req.into_parts(); - signer - .sign(&mut parts, &cred) + builder + .build(&ctx, &mut parts, Some(&cred), None) + .await .expect("sign request must success"); Request::from_parts(parts, body) }; @@ -289,12 +291,10 @@ async fn test_head_object_with_encoded_characters() -> Result<()> { #[tokio::test] async fn test_list_bucket() -> Result<()> { - let signer = init_signer().await; - if signer.is_none() { + let Some((ctx, loader, builder)) = init_default_loader().await else { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); - } - let (loader, signer) = signer.unwrap(); + }; let url = &env::var("REQSIGN_AWS_V4_URL").expect("env REQSIGN_AWS_V4_URL must set"); @@ -304,15 +304,16 @@ async fn test_list_bucket() -> Result<()> { http::Uri::from_str(&format!("{url}?list-type=2&delimiter=/&encoding-type=url"))?; let cred = loader - .load() + .load(&ctx) .await .expect("load request must success") .unwrap(); let req = { let (mut parts, body) = req.into_parts(); - signer - .sign(&mut parts, &cred) + builder + .build(&ctx, &mut parts, Some(&cred), None) + .await .expect("sign request must success"); Request::from_parts(parts, body) }; @@ -329,3 +330,180 @@ async fn test_list_bucket() -> Result<()> { assert_eq!(StatusCode::OK, resp.status()); Ok(()) } + +#[tokio::test] +async fn test_signer_with_web_loader() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + dotenv::from_filename("../../../.env").ok(); + + if env::var("REQSIGN_AWS_S3_TEST").is_err() || env::var("REQSIGN_AWS_S3_TEST").unwrap() != "on" + { + return Ok(()); + } + + // Ignore test if role_arn not set + let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") { + v + } else { + return Ok(()); + }; + + let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist"); + + let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist"); + let file_path = format!( + "{}/testdata/web_identity_token_file", + env::current_dir() + .expect("current_dir must exist") + .to_string_lossy() + ); + fs::write(&file_path, github_token).await?; + + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + ("AWS_REGION".to_string(), region.to_string()), + ("AWS_ROLE_ARN".to_string(), role_arn.to_string()), + ( + "AWS_WEB_IDENTITY_TOKEN_FILE".to_string(), + file_path.to_string(), + ), + ]), + }); + + let config = Config::default().from_env(&context); + let loader = DefaultLoader::new(config.into()); + + let builder = Builder::new("s3", ®ion); + + let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region); + let mut req = Request::new(""); + *req.method_mut() = http::Method::GET; + *req.uri_mut() = http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap(); + + let cred = loader + .load(&context) + .await + .expect("credential must be valid") + .unwrap(); + + let (mut req, body) = req.into_parts(); + builder + .build(&context, &mut req, Some(&cred), None) + .await + .expect("sign must success"); + let req = Request::from_parts(req, body); + + debug!("signed request url: {:?}", req.uri().to_string()); + debug!("signed request: {:?}", req); + + let client = Client::new(); + let resp = client.execute(req.try_into().unwrap()).await.unwrap(); + + let status = resp.status(); + debug!("got response: {:?}", resp); + debug!("got response content: {:?}", resp.text().await.unwrap()); + assert_eq!(status, StatusCode::NOT_FOUND); + Ok(()) +} + +#[tokio::test] +async fn test_signer_with_web_loader_assume_role() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + dotenv::from_filename("../../../.env").ok(); + + if env::var("REQSIGN_AWS_S3_TEST").is_err() || env::var("REQSIGN_AWS_S3_TEST").unwrap() != "on" + { + return Ok(()); + } + + // Ignore test if role_arn not set + let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ROLE_ARN") { + v + } else { + return Ok(()); + }; + // Ignore test if assume_role_arn not set + let assume_role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") { + v + } else { + return Ok(()); + }; + + let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist"); + + let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist"); + let file_path = format!( + "{}/testdata/web_identity_token_file", + env::current_dir() + .expect("current_dir must exist") + .to_string_lossy() + ); + fs::write(&file_path, github_token).await?; + + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + ("AWS_REGION".to_string(), region.to_string()), + ("AWS_ROLE_ARN".to_string(), role_arn.to_string()), + ( + "AWS_WEB_IDENTITY_TOKEN_FILE".to_string(), + file_path.to_string(), + ), + ]), + }); + + let cfg = Config { + ec2_metadata_disabled: true, + ..Default::default() + }; + let cfg: Arc = cfg.from_env(&context).into(); + + let default_loader = DefaultLoader::new(cfg.clone()); + let sts_signer = Signer::new( + context.clone(), + default_loader, + Builder::new("sts", ®ion), + ); + + let cfg = Config { + role_arn: Some(assume_role_arn.clone()), + region: Some(region.clone()), + sts_regional_endpoints: "regional".to_string(), + ..Default::default() + }; + let loader = + AssumeRoleLoader::new(cfg.into(), sts_signer).expect("AssumeRoleLoader must be valid"); + + let builder = Builder::new("s3", ®ion); + let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region); + let mut req = Request::new(""); + *req.method_mut() = http::Method::GET; + *req.uri_mut() = http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap(); + let cred = loader + .load(&context) + .await + .expect("credential must be valid") + .unwrap(); + + let (mut parts, body) = req.into_parts(); + builder + .build(&context, &mut parts, Some(&cred), None) + .await + .expect("sign must success"); + let req = Request::from_parts(parts, body); + + debug!("signed request url: {:?}", req.uri().to_string()); + debug!("signed request: {:?}", req); + let client = Client::new(); + let resp = client.execute(req.try_into().unwrap()).await.unwrap(); + let status = resp.status(); + debug!("got response: {:?}", resp); + debug!("got response content: {:?}", resp.text().await.unwrap()); + assert_eq!(status, StatusCode::NOT_FOUND); + Ok(()) +}