Skip to content

Commit 5317a82

Browse files
committed
feat: add Limited body
1 parent a97da64 commit 5317a82

File tree

2 files changed

+283
-0
lines changed

2 files changed

+283
-0
lines changed

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
1616
mod empty;
1717
mod full;
18+
mod limited;
1819
mod next;
1920
mod size_hint;
2021

2122
pub mod combinators;
2223

2324
pub use self::empty::Empty;
2425
pub use self::full::Full;
26+
pub use self::limited::{LengthLimitError, Limited};
2527
pub use self::next::{Data, Trailers};
2628
pub use self::size_hint::SizeHint;
2729

src/limited.rs

+281
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
use crate::{Body, SizeHint};
2+
use bytes::Buf;
3+
use http::HeaderMap;
4+
use std::error::Error;
5+
use std::fmt;
6+
use std::pin::Pin;
7+
use std::task::{Context, Poll};
8+
9+
/// A length limited body.
10+
///
11+
/// This body will return an error if more than `N` bytes are returned
12+
/// on polling the wrapped body.
13+
#[derive(Clone, Copy, Debug)]
14+
pub struct Limited<B, const N: usize> {
15+
remaining: usize,
16+
inner: B,
17+
}
18+
19+
impl<B> Limited<B, 0> {
20+
/// Create a new `Limited`.
21+
pub fn new<const N: usize>(inner: B) -> Limited<B, N> {
22+
Limited {
23+
remaining: N,
24+
inner,
25+
}
26+
}
27+
}
28+
29+
impl<B, const N: usize> Default for Limited<B, N>
30+
where
31+
B: Default,
32+
{
33+
fn default() -> Self {
34+
Limited::new(B::default())
35+
}
36+
}
37+
38+
impl<B, const N: usize> Body for Limited<B, N>
39+
where
40+
B: Body + Unpin,
41+
{
42+
type Data = B::Data;
43+
type Error = LengthLimitError<B::Error>;
44+
45+
fn poll_data(
46+
self: Pin<&mut Self>,
47+
cx: &mut Context<'_>,
48+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
49+
let mut this = self;
50+
let res = match Pin::new(&mut this.inner).poll_data(cx) {
51+
Poll::Pending => return Poll::Pending,
52+
Poll::Ready(None) => None,
53+
Poll::Ready(Some(Ok(data))) => {
54+
if data.remaining() > this.remaining {
55+
this.remaining = 0;
56+
// Some(Ok(data))
57+
Some(Err(LengthLimitError::LengthLimitExceeded))
58+
} else {
59+
this.remaining -= data.remaining();
60+
Some(Ok(data))
61+
}
62+
}
63+
Poll::Ready(Some(Err(err))) => Some(Err(LengthLimitError::Other(err))),
64+
};
65+
66+
Poll::Ready(res)
67+
}
68+
69+
fn poll_trailers(
70+
self: Pin<&mut Self>,
71+
cx: &mut Context<'_>,
72+
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
73+
let mut this = self;
74+
let res = match Pin::new(&mut this.inner).poll_trailers(cx) {
75+
Poll::Pending => return Poll::Pending,
76+
Poll::Ready(Ok(data)) => Ok(data),
77+
Poll::Ready(Err(err)) => Err(LengthLimitError::Other(err)),
78+
};
79+
80+
Poll::Ready(res)
81+
}
82+
83+
fn is_end_stream(&self) -> bool {
84+
self.inner.is_end_stream()
85+
}
86+
87+
fn size_hint(&self) -> SizeHint {
88+
use std::convert::TryFrom;
89+
match u64::try_from(N) {
90+
Ok(n) => {
91+
let mut hint = self.inner.size_hint();
92+
if hint.lower() >= n {
93+
hint.set_exact(n)
94+
} else if let Some(max) = hint.upper() {
95+
hint.set_upper(n.min(max))
96+
} else {
97+
hint.set_upper(n)
98+
}
99+
hint
100+
}
101+
Err(_) => self.inner.size_hint(),
102+
}
103+
}
104+
}
105+
106+
/// An error returned when reading from a [`Limited`] body.
107+
#[derive(Debug)]
108+
pub enum LengthLimitError<E> {
109+
/// The body exceeded the length limit.
110+
LengthLimitExceeded,
111+
/// Some other error was encountered while reading from the underlying body.
112+
Other(E),
113+
}
114+
115+
impl<E> fmt::Display for LengthLimitError<E>
116+
where
117+
E: fmt::Display,
118+
{
119+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120+
match self {
121+
Self::LengthLimitExceeded => f.write_str("length limit exceeded"),
122+
Self::Other(err) => err.fmt(f),
123+
}
124+
}
125+
}
126+
127+
impl<E> Error for LengthLimitError<E>
128+
where
129+
E: Error,
130+
{
131+
fn source(&self) -> Option<&(dyn Error + 'static)> {
132+
match self {
133+
Self::LengthLimitExceeded => None,
134+
Self::Other(err) => err.source(),
135+
}
136+
}
137+
}
138+
139+
#[cfg(test)]
140+
mod tests {
141+
use super::*;
142+
use crate::Full;
143+
use bytes::Bytes;
144+
use std::convert::Infallible;
145+
146+
#[tokio::test]
147+
async fn read_for_body_under_limit_returns_data() {
148+
const DATA: &[u8] = b"testing";
149+
let inner = Full::new(Bytes::from(DATA));
150+
let body = &mut Limited::new::<8>(inner);
151+
let data = body.data().await.unwrap().unwrap();
152+
assert_eq!(data, DATA);
153+
assert!(matches!(body.data().await, None));
154+
}
155+
156+
#[tokio::test]
157+
async fn read_for_body_over_limit_returns_error() {
158+
const DATA: &[u8] = b"testing a string that is too long";
159+
let inner = Full::new(Bytes::from(DATA));
160+
let body = &mut Limited::new::<8>(inner);
161+
let error = body.data().await.unwrap().unwrap_err();
162+
assert!(matches!(error, LengthLimitError::LengthLimitExceeded));
163+
}
164+
165+
struct Chunky(&'static [&'static [u8]]);
166+
167+
impl Body for Chunky {
168+
type Data = &'static [u8];
169+
type Error = Infallible;
170+
171+
fn poll_data(
172+
self: Pin<&mut Self>,
173+
_cx: &mut Context<'_>,
174+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
175+
let mut this = self;
176+
match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) {
177+
Some((data, new_tail)) => {
178+
this.0 = new_tail;
179+
180+
Poll::Ready(Some(data))
181+
}
182+
None => Poll::Ready(None),
183+
}
184+
}
185+
186+
fn poll_trailers(
187+
self: Pin<&mut Self>,
188+
_cx: &mut Context<'_>,
189+
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
190+
Poll::Ready(Ok(Some(HeaderMap::new())))
191+
}
192+
}
193+
194+
#[tokio::test]
195+
async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
196+
) {
197+
const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
198+
let inner = Chunky(DATA);
199+
let body = &mut Limited::new::<8>(inner);
200+
let data = body.data().await.unwrap().unwrap();
201+
assert_eq!(data, DATA[0]);
202+
let error = body.data().await.unwrap().unwrap_err();
203+
assert!(matches!(error, LengthLimitError::LengthLimitExceeded));
204+
}
205+
206+
#[tokio::test]
207+
async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
208+
const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
209+
let inner = Chunky(DATA);
210+
let body = &mut Limited::new::<8>(inner);
211+
let error = body.data().await.unwrap().unwrap_err();
212+
assert!(matches!(error, LengthLimitError::LengthLimitExceeded));
213+
}
214+
215+
#[tokio::test]
216+
async fn read_for_chunked_body_under_limit_is_okay() {
217+
const DATA: &[&[u8]] = &[b"test", b"ing!"];
218+
let inner = Chunky(DATA);
219+
let body = &mut Limited::new::<8>(inner);
220+
let data = body.data().await.unwrap().unwrap();
221+
assert_eq!(data, DATA[0]);
222+
let data = body.data().await.unwrap().unwrap();
223+
assert_eq!(data, DATA[1]);
224+
assert!(matches!(body.data().await, None));
225+
}
226+
227+
#[tokio::test]
228+
async fn read_for_trailers_propagates_inner_trailers() {
229+
const DATA: &[&[u8]] = &[b"test", b"ing!"];
230+
let inner = Chunky(DATA);
231+
let body = &mut Limited::new::<8>(inner);
232+
let trailers = body.trailers().await.unwrap();
233+
assert_eq!(trailers, Some(HeaderMap::new()))
234+
}
235+
236+
enum ErrorBodyError {
237+
Data,
238+
Trailers,
239+
}
240+
241+
struct ErrorBody;
242+
243+
impl Body for ErrorBody {
244+
type Data = &'static [u8];
245+
type Error = ErrorBodyError;
246+
247+
fn poll_data(
248+
self: Pin<&mut Self>,
249+
_cx: &mut Context<'_>,
250+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
251+
Poll::Ready(Some(Err(ErrorBodyError::Data)))
252+
}
253+
254+
fn poll_trailers(
255+
self: Pin<&mut Self>,
256+
_cx: &mut Context<'_>,
257+
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
258+
Poll::Ready(Err(ErrorBodyError::Trailers))
259+
}
260+
}
261+
262+
#[tokio::test]
263+
async fn read_for_body_returning_error_propagates_error() {
264+
let body = &mut Limited::new::<8>(ErrorBody);
265+
let error = body.data().await.unwrap().unwrap_err();
266+
assert!(matches!(
267+
error,
268+
LengthLimitError::Other(ErrorBodyError::Data)
269+
));
270+
}
271+
272+
#[tokio::test]
273+
async fn trailers_for_body_returning_error_propagates_error() {
274+
let body = &mut Limited::new::<8>(ErrorBody);
275+
let error = body.trailers().await.unwrap_err();
276+
assert!(matches!(
277+
error,
278+
LengthLimitError::Other(ErrorBodyError::Trailers)
279+
));
280+
}
281+
}

0 commit comments

Comments
 (0)