fix: prevent http connection race condition after restoring from Lambda SnapStart #569

Open
wants to merge 7 commits into
base: main
3 changes: 3 additions & 0 deletions README.md
@@ -106,6 +106,7 @@ The readiness check port/path and traffic port can be configured using environme
| AWS_LWA_PASS_THROUGH_PATH | the path for receiving event payloads that are passed through from non-http triggers | "/events" |
| AWS_LWA_AUTHORIZATION_SOURCE | a header name to be replaced to `Authorization` | None |
| AWS_LWA_ERROR_STATUS_CODES | comma-separated list of HTTP status codes that will cause Lambda invocations to fail (e.g. "500,502-504,422") | None |
| AWS_LWA_CLIENT_IDLE_TIMEOUT_MS | HTTP client idle timeout in milliseconds | "4000" |

> **Note:**
> We use the "AWS_LWA_" prefix to namespace all environment variables used by Lambda Web Adapter. The original ones will be supported until we reach version 1.0.
@@ -140,6 +141,8 @@ Please check out [FastAPI with Response Streaming](examples/fastapi-response-str

**AWS_LWA_ERROR_STATUS_CODES** - A comma-separated list of HTTP status codes that will cause Lambda invocations to fail. Supports individual codes and ranges (e.g. "500,502-504,422"). When the web application returns any of these status codes, the Lambda invocation will fail and trigger error handling behaviors like retries or DLQ processing. This is useful for treating certain HTTP errors as Lambda execution failures. This feature is disabled by default.
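
As a rough illustration of the code-and-range syntax described above, the following sketch expands a spec string into individual status codes. `parse_codes` is a hypothetical helper written for this example, not the adapter's own `parse_status_codes`, and silently skipping malformed entries is an assumption rather than documented behavior.

```rust
/// Hypothetical parser for a spec such as "500,502-504,422".
/// Malformed entries are skipped (an assumption made for this sketch).
fn parse_codes(spec: &str) -> Vec<u16> {
    let mut codes = Vec::new();
    for part in spec.split(',').map(str::trim).filter(|p| !p.is_empty()) {
        match part.split_once('-') {
            // A range such as "502-504" expands to every code in it, inclusive.
            Some((start, end)) => {
                if let (Ok(start), Ok(end)) = (start.trim().parse::<u16>(), end.trim().parse::<u16>()) {
                    codes.extend(start..=end);
                }
            }
            // A single code such as "500".
            None => {
                if let Ok(code) = part.parse::<u16>() {
                    codes.push(code);
                }
            }
        }
    }
    codes
}
```

With the example value from the table, `parse_codes("500,502-504,422")` yields the codes 500, 502, 503, 504 and 422, in that order.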

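**AWS_LWA_CLIENT_IDLE_TIMEOUT_MS** - How long, in milliseconds, the adapter's HTTP client keeps an idle connection to the web application in its connection pool. If no invocation arrives within this period (for example, while a function is restored from Lambda SnapStart), the adapter creates a new client with a fresh connection pool before the next request. The default is 4000 milliseconds.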
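
A minimal sketch of how such a value could be turned into a `Duration`, assuming a fallback to the default when the variable is unset or not a valid number. `idle_timeout_from` and `DEFAULT_IDLE_TIMEOUT_MS` are hypothetical names for this example; the code in this PR parses the variable with `unwrap()` and therefore panics on invalid input instead of falling back.

```rust
use std::time::Duration;

/// Default idle timeout in milliseconds, matching the documented default.
const DEFAULT_IDLE_TIMEOUT_MS: u64 = 4000;

/// Hypothetical helper: convert the raw value of `AWS_LWA_CLIENT_IDLE_TIMEOUT_MS`
/// into a `Duration`, using the default when the value is missing or unparsable.
fn idle_timeout_from(raw: Option<&str>) -> Duration {
    let ms = raw
        .and_then(|s| s.trim().parse::<u64>().ok())
        .unwrap_or(DEFAULT_IDLE_TIMEOUT_MS);
    Duration::from_millis(ms)
}
```

It could be called as `idle_timeout_from(std::env::var("AWS_LWA_CLIENT_IDLE_TIMEOUT_MS").ok().as_deref())`. Whether to fall back or fail fast on a malformed value is a design choice; failing fast surfaces misconfiguration at startup, while falling back keeps the function running.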

## Request Context

**Request Context** is metadata API Gateway sends to Lambda for a request. It usually contains requestId, requestTime, apiId, identity, and authorizer. Identity and authorizer are useful to get client identity for authorization. API Gateway Developer Guide contains more details [here](https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format).
60 changes: 53 additions & 7 deletions src/lib.rs
@@ -16,7 +16,6 @@ use lambda_http::Body;
pub use lambda_http::Error;
use lambda_http::{Request, RequestExt, Response};
use readiness::Checkpoint;
use std::fmt::Debug;
use std::{
env,
future::Future,
@@ -27,7 +26,9 @@ use std::{
},
time::Duration,
};
use tokio::{net::TcpStream, time::timeout};
use std::{fmt::Debug, time::SystemTime};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_retry::{strategy::FixedInterval, Retry};
use tower::{Service, ServiceBuilder};
use tower_http::compression::CompressionLayer;
@@ -81,6 +82,7 @@ pub struct AdapterOptions {
pub invoke_mode: LambdaInvokeMode,
pub authorization_source: Option<String>,
pub error_status_codes: Option<Vec<u16>>,
pub client_idle_timeout_ms: u64,
}

impl Default for AdapterOptions {
@@ -122,6 +124,10 @@ impl Default for AdapterOptions {
error_status_codes: env::var("AWS_LWA_ERROR_STATUS_CODES")
.ok()
.map(|codes| parse_status_codes(&codes)),
client_idle_timeout_ms: env::var("AWS_LWA_CLIENT_IDLE_TIMEOUT_MS")
.ok()
.map(|s| s.parse().unwrap())
.unwrap_or(4000),
}
}
}
@@ -170,17 +176,25 @@ pub struct Adapter<C, B> {
invoke_mode: LambdaInvokeMode,
authorization_source: Option<String>,
error_status_codes: Option<Vec<u16>>,
client_idle_timeout_ms: u64,
// be sure to use `SystemTime` (CLOCK_REALTIME) instead of `Instant` (CLOCK_MONOTONIC)
// to avoid issues when restored from Lambda SnapStart
last_invoke: SystemTime,
}

impl Adapter<HttpConnector, Body> {
fn new_client(timeout_ms: u64) -> Arc<Client<HttpConnector, Body>> {
Arc::new(
Client::builder(hyper_util::rt::TokioExecutor::new())
.pool_idle_timeout(Duration::from_millis(timeout_ms))
.build(HttpConnector::new()),
)
}

/// Create a new HTTP Adapter instance.
/// This function initializes a new HTTP client
/// to talk with the web server.
pub fn new(options: &AdapterOptions) -> Adapter<HttpConnector, Body> {
let client = Client::builder(hyper_util::rt::TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(4))
.build(HttpConnector::new());

let schema = "http";

let healthcheck_url = format!(
@@ -195,7 +209,7 @@ impl Adapter<HttpConnector, Body> {
.unwrap();

Adapter {
client: Arc::new(client),
client: Self::new_client(options.client_idle_timeout_ms),
healthcheck_url,
healthcheck_protocol: options.readiness_check_protocol,
healthcheck_min_unhealthy_status: options.readiness_check_min_unhealthy_status,
@@ -208,6 +222,9 @@ impl Adapter<HttpConnector, Body> {
invoke_mode: options.invoke_mode,
authorization_source: options.authorization_source.clone(),
error_status_codes: options.error_status_codes.clone(),
client_idle_timeout_ms: options.client_idle_timeout_ms,
// it's ok to use `now` here since there are no connections in the connection pool yet
last_invoke: SystemTime::now(),
}
}
}
@@ -403,6 +420,15 @@ impl Adapter<HttpConnector, Body> {

Ok(app_response)
}

/// Returns whether the client has been idle for longer than [`Self::client_idle_timeout_ms`].
fn client_timeout_has_expired(&self) -> bool {
self.last_invoke
.elapsed()
.map(|d| d.as_millis() > self.client_idle_timeout_ms.into())
// if `last_invoke` is in the future (e.g. the wall clock moved backwards), it's ok to reuse the client
.unwrap_or(false)
}
}

/// Implement a `Tower.Service` that sends the requests
@@ -417,7 +443,15 @@ impl Service<Request> for Adapter<HttpConnector, Body> {
}

fn call(&mut self, event: Request) -> Self::Future {
if self.client_timeout_has_expired() {
// The client has been idle past the timeout: create a new client with a new connection pool.
// This prevents the pool from reusing a connection that is about to be closed after restoring from Lambda SnapStart.
tracing::debug!("Client timeout, creating a new client");
self.client = Self::new_client(self.client_idle_timeout_ms);
}

let adapter = self.clone();
self.last_invoke = SystemTime::now();
Box::pin(async move { adapter.fetch_response(event).await })
}
}
@@ -537,4 +571,16 @@ mod tests {
// Assert app server's healthcheck endpoint got called
healthcheck.assert();
}

#[test]
fn test_client_idle_timeout() {
let mut adapter = Adapter::new(&AdapterOptions::default());
assert!(!adapter.client_timeout_has_expired());

adapter.last_invoke = SystemTime::now() - Duration::from_millis(5000);
assert!(adapter.client_timeout_has_expired());

adapter.last_invoke = SystemTime::now() + Duration::from_millis(5000);
assert!(!adapter.client_timeout_has_expired());
}
}
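
The idle check in this diff can be sketched in isolation as follows. This is a simplified stand-in rather than the adapter's code: `idle_expired` is a hypothetical free function, and it takes `now` as a parameter (an assumption made here so that the behavior can be exercised deterministically, without reading the real clock).

```rust
use std::time::{Duration, SystemTime};

/// Hypothetical stand-in for `client_timeout_has_expired`: reports whether more than
/// `idle_timeout` of wall-clock time has passed between `last_invoke` and `now`.
fn idle_expired(last_invoke: SystemTime, now: SystemTime, idle_timeout: Duration) -> bool {
    match now.duration_since(last_invoke) {
        Ok(elapsed) => elapsed > idle_timeout,
        // `last_invoke` is later than `now` (the wall clock moved backwards):
        // treat the client as still usable, mirroring `unwrap_or(false)` in the diff.
        Err(_) => false,
    }
}
```

The three cases covered by `test_client_idle_timeout` map directly onto this function: no elapsed time (not expired), 5000 ms elapsed against a 4000 ms timeout (expired), and a `last_invoke` in the future (not expired).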