Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/axum-key-value-store/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ fn app() -> Router {
)
.sensitive_response_headers(sensitive_headers)
// Set a timeout
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(10)))
// Compress responses
.compression()
// Set a `Content-Type` if there isn't one already.
Expand Down
16 changes: 8 additions & 8 deletions tower-http/src/timeout/mod.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
//! Middleware that applies a timeout to requests.
//!
//! If the request does not complete within the specified timeout it will be aborted and a `408
//! Request Timeout` response will be sent.
//! If the request does not complete within the specified timeout, it will be aborted and a
//! response with an empty body and a custom status code will be returned.
//!
//! # Differences from `tower::timeout`
//!
//! tower's [`Timeout`](tower::timeout::Timeout) middleware uses an error to signal timeout, i.e.
//! it changes the error type to [`BoxError`](tower::BoxError). For HTTP services that is rarely
//! what you want as returning errors will terminate the connection without sending a response.
//!
//! This middleware won't change the error type and instead return a `408 Request Timeout`
//! response. That means if your service's error type is [`Infallible`] it will still be
//! [`Infallible`] after applying this middleware.
//! This middleware won't change the error type and instead returns a response with an empty body
//! and the specified status code. That means if your service's error type is [`Infallible`], it will
//! still be [`Infallible`] after applying this middleware.
//!
//! # Example
//!
//! ```
//! use http::{Request, Response};
//! use http::{Request, Response, StatusCode};
//! use http_body_util::Full;
//! use bytes::Bytes;
//! use std::{convert::Infallible, time::Duration};
Expand All @@ -31,8 +31,8 @@
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let svc = ServiceBuilder::new()
//! // Timeout requests after 30 seconds
//! .layer(TimeoutLayer::new(Duration::from_secs(30)))
//! // Timeout requests after 30 seconds with the specified status code
//! .layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(30)))
//! .service_fn(handle);
//! # Ok(())
//! # }
Expand Down
141 changes: 133 additions & 8 deletions tower-http/src/timeout/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,81 @@ use tower_service::Service;
#[derive(Debug, Clone, Copy)]
pub struct TimeoutLayer {
timeout: Duration,
status_code: StatusCode,
}

impl TimeoutLayer {
/// Creates a new [`TimeoutLayer`].
///
/// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout.
/// To customize the response status code, use the `with_status_code` method.
#[deprecated(since = "0.6.7", note = "Use `TimeoutLayer::with_status_code` instead")]
pub fn new(timeout: Duration) -> Self {
TimeoutLayer { timeout }
Self::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
}

/// Creates a new [`TimeoutLayer`] with the specified status code for the timeout response.
pub fn with_status_code(status_code: StatusCode, timeout: Duration) -> Self {
Self {
timeout,
status_code,
}
}
}

impl<S> Layer<S> for TimeoutLayer {
type Service = Timeout<S>;

fn layer(&self, inner: S) -> Self::Service {
Timeout::new(inner, self.timeout)
Timeout::with_status_code(inner, self.status_code, self.timeout)
}
}

/// Middleware which apply a timeout to requests.
///
/// If the request does not complete within the specified timeout it will be aborted and a `408
/// Request Timeout` response will be sent.
///
/// See the [module docs](super) for an example.
#[derive(Debug, Clone, Copy)]
pub struct Timeout<S> {
inner: S,
timeout: Duration,
status_code: StatusCode,
}

impl<S> Timeout<S> {
/// Creates a new [`Timeout`].
///
/// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout.
/// To customize the response status code, use the `with_status_code` method.
#[deprecated(since = "0.6.7", note = "Use `Timeout::with_status_code` instead")]
pub fn new(inner: S, timeout: Duration) -> Self {
Self { inner, timeout }
Self::with_status_code(inner, StatusCode::REQUEST_TIMEOUT, timeout)
}

/// Creates a new [`Timeout`] with the specified status code for the timeout response.
pub fn with_status_code(inner: S, status_code: StatusCode, timeout: Duration) -> Self {
Self {
inner,
timeout,
status_code,
}
}

define_inner_service_accessors!();

/// Returns a new [`Layer`] that wraps services with a `Timeout` middleware.
///
/// [`Layer`]: tower_layer::Layer
#[deprecated(
since = "0.6.7",
note = "Use `Timeout::layer_with_status_code` instead"
)]
pub fn layer(timeout: Duration) -> TimeoutLayer {
TimeoutLayer::new(timeout)
TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
}

/// Returns a new [`Layer`] that wraps services with a `Timeout` middleware with the specified status code.
pub fn layer_with_status_code(status_code: StatusCode, timeout: Duration) -> TimeoutLayer {
TimeoutLayer::with_status_code(status_code, timeout)
}
}

Expand All @@ -81,6 +114,7 @@ where
ResponseFuture {
inner: self.inner.call(req),
sleep,
status_code: self.status_code,
}
}
}
Expand All @@ -92,6 +126,7 @@ pin_project! {
inner: F,
#[pin]
sleep: Sleep,
status_code: StatusCode,
}
}

Expand All @@ -107,7 +142,7 @@ where

if this.sleep.poll(cx).is_ready() {
let mut res = Response::new(B::default());
*res.status_mut() = StatusCode::REQUEST_TIMEOUT;
*res.status_mut() = *this.status_code;
return Poll::Ready(Ok(res));
}

Expand Down Expand Up @@ -269,3 +304,93 @@ where
Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body))))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::Body;
use http::{Request, Response, StatusCode};
use std::time::Duration;
use tower::{BoxError, ServiceBuilder, ServiceExt};

#[tokio::test]
async fn request_completes_within_timeout() {
let mut service = ServiceBuilder::new()
.layer(TimeoutLayer::with_status_code(
StatusCode::GATEWAY_TIMEOUT,
Duration::from_secs(1),
))
.service_fn(fast_handler);

let request = Request::get("/").body(Body::empty()).unwrap();
let res = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn timeout_middleware_with_custom_status_code() {
let timeout_service = Timeout::with_status_code(
tower::service_fn(slow_handler),
StatusCode::REQUEST_TIMEOUT,
Duration::from_millis(10),
);

let mut service = ServiceBuilder::new().service(timeout_service);

let request = Request::get("/").body(Body::empty()).unwrap();
let res = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
}

#[tokio::test]
async fn timeout_response_has_empty_body() {
let mut service = ServiceBuilder::new()
.layer(TimeoutLayer::with_status_code(
StatusCode::GATEWAY_TIMEOUT,
Duration::from_millis(10),
))
.service_fn(slow_handler);

let request = Request::get("/").body(Body::empty()).unwrap();
let res = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(res.status(), StatusCode::GATEWAY_TIMEOUT);

// Verify the body is empty (default)
use http_body_util::BodyExt;
let body = res.into_body();
let bytes = body.collect().await.unwrap().to_bytes();
assert!(bytes.is_empty());
}

#[tokio::test]
async fn deprecated_new_method_compatibility() {
#[allow(deprecated)]
let layer = TimeoutLayer::new(Duration::from_millis(10));

let mut service = ServiceBuilder::new().layer(layer).service_fn(slow_handler);

let request = Request::get("/").body(Body::empty()).unwrap();
let res = service.ready().await.unwrap().call(request).await.unwrap();

// Should use default 408 status code
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
}

async fn slow_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
tokio::time::sleep(Duration::from_secs(10)).await;
Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.unwrap())
}

async fn fast_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.unwrap())
}
}