Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tower-http/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,16 @@ pub trait ServiceBuilderExt<L>: sealed::Sealed<L> + Sized {
fn trim_trailing_slash(
self,
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>;

/// Append trailing slash to paths.
///
/// See [`tower_http::normalize_path`] for more details.
///
/// [`tower_http::normalize_path`]: crate::normalize_path
#[cfg(feature = "normalize-path")]
fn append_trailing_slash(
self,
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>;
}

impl<L> sealed::Sealed<L> for ServiceBuilder<L> {}
Expand Down Expand Up @@ -596,4 +606,11 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> {
self.layer(crate::normalize_path::NormalizePathLayer::trim_trailing_slash())
}

#[cfg(feature = "normalize-path")]
fn append_trailing_slash(
self,
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> {
self.layer(crate::normalize_path::NormalizePathLayer::append_trailing_slash())
}
}
200 changes: 177 additions & 23 deletions tower-http/src/normalize_path.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
//! Middleware that normalizes paths.
//!
//! Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
//! will be changed to `/foo` before reaching the inner service.
//!
//! # Example
//!
//! ```
Expand Down Expand Up @@ -45,27 +42,53 @@ use std::{
use tower_layer::Layer;
use tower_service::Service;

/// Different modes of normalizing paths
#[derive(Debug, Copy, Clone)]
enum NormalizeMode {
/// Normalizes paths by trimming the trailing slashes, e.g. /foo/ -> /foo
Trim,
/// Normalizes paths by appending trailing slash, e.g. /foo -> /foo/
Append,
}

/// Layer that applies [`NormalizePath`] which normalizes paths.
///
/// See the [module docs](self) for more details.
#[derive(Debug, Copy, Clone)]
pub struct NormalizePathLayer {}
pub struct NormalizePathLayer {
mode: NormalizeMode,
}

impl NormalizePathLayer {
/// Create a new [`NormalizePathLayer`].
///
/// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
/// will be changed to `/foo` before reaching the inner service.
pub fn trim_trailing_slash() -> Self {
NormalizePathLayer {}
NormalizePathLayer {
mode: NormalizeMode::Trim,
}
}

/// Create a new [`NormalizePathLayer`].
///
/// Request paths without trailing slash will be appended with a trailing slash. For example, a request with `/foo`
/// will be changed to `/foo/` before reaching the inner service.
pub fn append_trailing_slash() -> Self {
NormalizePathLayer {
mode: NormalizeMode::Append,
}
}
}

impl<S> Layer<S> for NormalizePathLayer {
type Service = NormalizePath<S>;

fn layer(&self, inner: S) -> Self::Service {
NormalizePath::trim_trailing_slash(inner)
NormalizePath {
mode: self.mode,
inner,
}
}
}

Expand All @@ -74,16 +97,25 @@ impl<S> Layer<S> for NormalizePathLayer {
/// See the [module docs](self) for more details.
#[derive(Debug, Copy, Clone)]
pub struct NormalizePath<S> {
mode: NormalizeMode,
inner: S,
}

impl<S> NormalizePath<S> {
/// Create a new [`NormalizePath`].
///
/// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
/// will be changed to `/foo` before reaching the inner service.
/// Construct a new [`NormalizePath`] with trim mode.
pub fn trim_trailing_slash(inner: S) -> Self {
Self { inner }
Self {
mode: NormalizeMode::Trim,
inner,
}
}

/// Construct a new [`NormalizePath`] with append mode.
pub fn append_trailing_slash(inner: S) -> Self {
Self {
mode: NormalizeMode::Append,
inner,
}
}

define_inner_service_accessors!();
Expand All @@ -103,12 +135,15 @@ where
}

fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
normalize_trailing_slash(req.uri_mut());
match self.mode {
NormalizeMode::Trim => trim_trailing_slash(req.uri_mut()),
NormalizeMode::Append => append_trailing_slash(req.uri_mut()),
}
self.inner.call(req)
}
}

fn normalize_trailing_slash(uri: &mut Uri) {
fn trim_trailing_slash(uri: &mut Uri) {
if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
return;
}
Expand Down Expand Up @@ -137,14 +172,48 @@ fn normalize_trailing_slash(uri: &mut Uri) {
}
}

fn append_trailing_slash(uri: &mut Uri) {
if uri.path().ends_with("/") && !uri.path().ends_with("//") {
return;
}

let trimmed = uri.path().trim_matches('/');
let new_path = if trimmed.is_empty() {
"/".to_string()
} else {
format!("/{trimmed}/")
};

let mut parts = uri.clone().into_parts();

let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
let new_path_and_query = if let Some(query) = path_and_query.query() {
Cow::Owned(format!("{new_path}?{query}"))
} else {
new_path.into()
}
.parse()
.unwrap();

Some(new_path_and_query)
} else {
Some(new_path.parse().unwrap())
};

parts.path_and_query = new_path_and_query;
if let Ok(new_uri) = Uri::from_parts(parts) {
*uri = new_uri;
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::convert::Infallible;
use tower::{ServiceBuilder, ServiceExt};

#[tokio::test]
async fn works() {
async fn trim_works() {
async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
Ok(Response::new(request.uri().to_string()))
}
Expand All @@ -168,63 +237,148 @@ mod tests {
#[test]
fn is_noop_if_no_trailing_slash() {
let mut uri = "/foo".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo");
}

#[test]
fn maintains_query() {
let mut uri = "/foo/?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo?a=a");
}

#[test]
fn removes_multiple_trailing_slashes() {
let mut uri = "/foo////".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo");
}

#[test]
fn removes_multiple_trailing_slashes_even_with_query() {
let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo?a=a");
}

#[test]
fn is_noop_on_index() {
let mut uri = "/".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn removes_multiple_trailing_slashes_on_index() {
let mut uri = "////".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn removes_multiple_trailing_slashes_on_index_even_with_query() {
let mut uri = "////?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/?a=a");
}

#[test]
fn removes_multiple_preceding_slashes_even_with_query() {
let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo?a=a");
}

#[test]
fn removes_multiple_preceding_slashes() {
let mut uri = "///foo".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo");
}

#[tokio::test]
async fn append_works() {
async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
Ok(Response::new(request.uri().to_string()))
}

let mut svc = ServiceBuilder::new()
.layer(NormalizePathLayer::append_trailing_slash())
.service_fn(handle);

let body = svc
.ready()
.await
.unwrap()
.call(Request::builder().uri("/foo").body(()).unwrap())
.await
.unwrap()
.into_body();

assert_eq!(body, "/foo/");
}

#[test]
fn is_noop_if_trailing_slash() {
let mut uri = "/foo/".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/");
}

#[test]
fn append_maintains_query() {
let mut uri = "/foo?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/?a=a");
}

#[test]
fn append_only_keeps_one_slash() {
let mut uri = "/foo////".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/");
}

#[test]
fn append_only_keeps_one_slash_even_with_query() {
let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/?a=a");
}

#[test]
fn append_is_noop_on_index() {
let mut uri = "/".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn append_removes_multiple_trailing_slashes_on_index() {
let mut uri = "////".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn append_removes_multiple_trailing_slashes_on_index_even_with_query() {
let mut uri = "////?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/?a=a");
}

#[test]
fn append_removes_multiple_preceding_slashes_even_with_query() {
let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/?a=a");
}

#[test]
fn append_removes_multiple_preceding_slashes() {
let mut uri = "///foo".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/");
}
}
13 changes: 13 additions & 0 deletions tower-http/src/service_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,19 @@ pub trait ServiceExt {
{
crate::normalize_path::NormalizePath::trim_trailing_slash(self)
}

/// Append trailing slash to paths.
///
/// See [`tower_http::normalize_path`] for more details.
///
/// [`tower_http::normalize_path`]: crate::normalize_path
#[cfg(feature = "normalize-path")]
fn append_trailing_slash(self) -> crate::normalize_path::NormalizePath<Self>
where
Self: Sized,
{
crate::normalize_path::NormalizePath::append_trailing_slash(self)
}
}

impl<T> ServiceExt for T {}
Expand Down