Skip to content

Commit a67b84b

Browse files
committed
feat: add Limited body
1 parent a97da64 commit a67b84b

File tree

2 files changed

+282
-0
lines changed

2 files changed

+282
-0
lines changed

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
1616
mod empty;
1717
mod full;
18+
mod limited;
1819
mod next;
1920
mod size_hint;
2021

2122
pub mod combinators;
2223

2324
pub use self::empty::Empty;
2425
pub use self::full::Full;
26+
pub use self::limited::{LengthLimitError, Limited};
2527
pub use self::next::{Data, Trailers};
2628
pub use self::size_hint::SizeHint;
2729

src/limited.rs

+280
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
use crate::{Body, SizeHint};
2+
use bytes::Buf;
3+
use http::HeaderMap;
4+
use std::error::Error;
5+
use std::fmt;
6+
use std::pin::Pin;
7+
use std::task::{Context, Poll};
8+
9+
/// A length limited body.
10+
///
11+
/// This body will return an error if more than `N` bytes are returned
12+
/// on polling the wrapped body.
13+
#[derive(Clone, Copy, Debug)]
14+
pub struct Limited<B, const N: usize> {
15+
remaining: usize,
16+
inner: B,
17+
}
18+
19+
impl<B> Limited<B, 0> {
20+
/// Create a new `Limited`.
21+
pub fn new<const N: usize>(inner: B) -> Limited<B, N> {
22+
Limited {
23+
remaining: N,
24+
inner,
25+
}
26+
}
27+
}
28+
29+
impl<B, const N: usize> Default for Limited<B, N>
30+
where
31+
B: Default,
32+
{
33+
fn default() -> Self {
34+
Limited::new(B::default())
35+
}
36+
}
37+
38+
impl<B, const N: usize> Body for Limited<B, N>
39+
where
40+
B: Body + Unpin,
41+
{
42+
type Data = B::Data;
43+
type Error = LengthLimitError<B::Error>;
44+
45+
fn poll_data(
46+
self: Pin<&mut Self>,
47+
cx: &mut Context<'_>,
48+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
49+
let mut this = self;
50+
let res = match Pin::new(&mut this.inner).poll_data(cx) {
51+
Poll::Pending => return Poll::Pending,
52+
Poll::Ready(None) => None,
53+
Poll::Ready(Some(Ok(data))) => {
54+
if data.remaining() > this.remaining {
55+
this.remaining = 0;
56+
Some(Err(LengthLimitError::LengthLimitExceeded))
57+
} else {
58+
this.remaining -= data.remaining();
59+
Some(Ok(data))
60+
}
61+
}
62+
Poll::Ready(Some(Err(err))) => Some(Err(LengthLimitError::Other(err))),
63+
};
64+
65+
Poll::Ready(res)
66+
}
67+
68+
fn poll_trailers(
69+
self: Pin<&mut Self>,
70+
cx: &mut Context<'_>,
71+
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
72+
let mut this = self;
73+
let res = match Pin::new(&mut this.inner).poll_trailers(cx) {
74+
Poll::Pending => return Poll::Pending,
75+
Poll::Ready(Ok(data)) => Ok(data),
76+
Poll::Ready(Err(err)) => Err(LengthLimitError::Other(err)),
77+
};
78+
79+
Poll::Ready(res)
80+
}
81+
82+
fn is_end_stream(&self) -> bool {
83+
self.inner.is_end_stream()
84+
}
85+
86+
fn size_hint(&self) -> SizeHint {
87+
use std::convert::TryFrom;
88+
match u64::try_from(N) {
89+
Ok(n) => {
90+
let mut hint = self.inner.size_hint();
91+
if hint.lower() >= n {
92+
hint.set_exact(n)
93+
} else if let Some(max) = hint.upper() {
94+
hint.set_upper(n.min(max))
95+
} else {
96+
hint.set_upper(n)
97+
}
98+
hint
99+
}
100+
Err(_) => self.inner.size_hint(),
101+
}
102+
}
103+
}
104+
105+
/// An error returned when reading from a [`Limited`] body.
106+
#[derive(Debug)]
107+
pub enum LengthLimitError<E> {
108+
/// The body exceeded the length limit.
109+
LengthLimitExceeded,
110+
/// Some other error was encountered while reading from the underlying body.
111+
Other(E),
112+
}
113+
114+
impl<E> fmt::Display for LengthLimitError<E>
115+
where
116+
E: fmt::Display,
117+
{
118+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
119+
match self {
120+
Self::LengthLimitExceeded => f.write_str("length limit exceeded"),
121+
Self::Other(err) => err.fmt(f),
122+
}
123+
}
124+
}
125+
126+
impl<E> Error for LengthLimitError<E>
127+
where
128+
E: Error,
129+
{
130+
fn source(&self) -> Option<&(dyn Error + 'static)> {
131+
match self {
132+
Self::LengthLimitExceeded => None,
133+
Self::Other(err) => err.source(),
134+
}
135+
}
136+
}
137+
138+
#[cfg(test)]
139+
mod tests {
140+
use super::*;
141+
use crate::Full;
142+
use bytes::Bytes;
143+
use std::convert::Infallible;
144+
145+
#[tokio::test]
146+
async fn read_for_body_under_limit_returns_data() {
147+
const DATA: &[u8] = b"testing";
148+
let inner = Full::new(Bytes::from(DATA));
149+
let body = &mut Limited::new::<8>(inner);
150+
let data = body.data().await.unwrap().unwrap();
151+
assert_eq!(data, DATA);
152+
assert!(matches!(body.data().await, None));
153+
}
154+
155+
#[tokio::test]
156+
async fn read_for_body_over_limit_returns_error() {
157+
const DATA: &[u8] = b"testing a string that is too long";
158+
let inner = Full::new(Bytes::from(DATA));
159+
let body = &mut Limited::new::<8>(inner);
160+
let error = body.data().await.unwrap().unwrap_err();
161+
assert!(matches!(error, LengthLimitError::LengthLimitExceeded));
162+
}
163+
164+
struct Chunky(&'static [&'static [u8]]);
165+
166+
impl Body for Chunky {
167+
type Data = &'static [u8];
168+
type Error = Infallible;
169+
170+
fn poll_data(
171+
self: Pin<&mut Self>,
172+
_cx: &mut Context<'_>,
173+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
174+
let mut this = self;
175+
match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) {
176+
Some((data, new_tail)) => {
177+
this.0 = new_tail;
178+
179+
Poll::Ready(Some(data))
180+
}
181+
None => Poll::Ready(None),
182+
}
183+
}
184+
185+
fn poll_trailers(
186+
self: Pin<&mut Self>,
187+
_cx: &mut Context<'_>,
188+
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
189+
Poll::Ready(Ok(Some(HeaderMap::new())))
190+
}
191+
}
192+
193+
#[tokio::test]
194+
async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
195+
) {
196+
const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
197+
let inner = Chunky(DATA);
198+
let body = &mut Limited::new::<8>(inner);
199+
let data = body.data().await.unwrap().unwrap();
200+
assert_eq!(data, DATA[0]);
201+
let error = body.data().await.unwrap().unwrap_err();
202+
assert!(matches!(error, LengthLimitError::LengthLimitExceeded));
203+
}
204+
205+
#[tokio::test]
206+
async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
207+
const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
208+
let inner = Chunky(DATA);
209+
let body = &mut Limited::new::<8>(inner);
210+
let error = body.data().await.unwrap().unwrap_err();
211+
assert!(matches!(error, LengthLimitError::LengthLimitExceeded));
212+
}
213+
214+
#[tokio::test]
215+
async fn read_for_chunked_body_under_limit_is_okay() {
216+
const DATA: &[&[u8]] = &[b"test", b"ing!"];
217+
let inner = Chunky(DATA);
218+
let body = &mut Limited::new::<8>(inner);
219+
let data = body.data().await.unwrap().unwrap();
220+
assert_eq!(data, DATA[0]);
221+
let data = body.data().await.unwrap().unwrap();
222+
assert_eq!(data, DATA[1]);
223+
assert!(matches!(body.data().await, None));
224+
}
225+
226+
#[tokio::test]
227+
async fn read_for_trailers_propagates_inner_trailers() {
228+
const DATA: &[&[u8]] = &[b"test", b"ing!"];
229+
let inner = Chunky(DATA);
230+
let body = &mut Limited::new::<8>(inner);
231+
let trailers = body.trailers().await.unwrap();
232+
assert_eq!(trailers, Some(HeaderMap::new()))
233+
}
234+
235+
enum ErrorBodyError {
236+
Data,
237+
Trailers,
238+
}
239+
240+
struct ErrorBody;
241+
242+
impl Body for ErrorBody {
243+
type Data = &'static [u8];
244+
type Error = ErrorBodyError;
245+
246+
fn poll_data(
247+
self: Pin<&mut Self>,
248+
_cx: &mut Context<'_>,
249+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
250+
Poll::Ready(Some(Err(ErrorBodyError::Data)))
251+
}
252+
253+
fn poll_trailers(
254+
self: Pin<&mut Self>,
255+
_cx: &mut Context<'_>,
256+
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
257+
Poll::Ready(Err(ErrorBodyError::Trailers))
258+
}
259+
}
260+
261+
#[tokio::test]
262+
async fn read_for_body_returning_error_propagates_error() {
263+
let body = &mut Limited::new::<8>(ErrorBody);
264+
let error = body.data().await.unwrap().unwrap_err();
265+
assert!(matches!(
266+
error,
267+
LengthLimitError::Other(ErrorBodyError::Data)
268+
));
269+
}
270+
271+
#[tokio::test]
272+
async fn trailers_for_body_returning_error_propagates_error() {
273+
let body = &mut Limited::new::<8>(ErrorBody);
274+
let error = body.trailers().await.unwrap_err();
275+
assert!(matches!(
276+
error,
277+
LengthLimitError::Other(ErrorBodyError::Trailers)
278+
));
279+
}
280+
}

0 commit comments

Comments
 (0)