Skip to content

Commit 4c29727

Browse files
committed
don't send 100-continue until the body has been read from
1 parent 9738d53 commit 4c29727

File tree

3 files changed

+87
-51
lines changed

3 files changed

+87
-51
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ const MAX_HEAD_LENGTH: usize = 8 * 1024;
106106

107107
mod chunked;
108108
mod date;
109+
mod read_notifier;
109110

110111
pub mod client;
111112
pub mod server;

src/read_notifier.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
use std::{
2+
fmt, io,
3+
pin::Pin,
4+
task::{Context, Poll},
5+
};
6+
7+
use async_std::{
8+
io::{BufRead, Read},
9+
sync::Sender,
10+
};
11+
12+
pin_project_lite::pin_project! {
13+
pub(crate) struct ReadNotifier<B>{
14+
#[pin]
15+
reader: B,
16+
sender: Sender<()>,
17+
read: bool
18+
}
19+
}
20+
21+
impl<B> fmt::Debug for ReadNotifier<B> {
22+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23+
f.debug_struct("ReadNotifier")
24+
.field("read", &self.read)
25+
.finish()
26+
}
27+
}
28+
29+
impl<B: BufRead> ReadNotifier<B> {
30+
pub(crate) fn new(reader: B, sender: Sender<()>) -> Self {
31+
Self {
32+
reader,
33+
sender,
34+
read: false,
35+
}
36+
}
37+
}
38+
39+
impl<B: BufRead> BufRead for ReadNotifier<B> {
40+
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
41+
self.project().reader.poll_fill_buf(cx)
42+
}
43+
44+
fn consume(self: Pin<&mut Self>, amt: usize) {
45+
self.project().reader.consume(amt)
46+
}
47+
}
48+
49+
impl<B: Read> Read for ReadNotifier<B> {
50+
fn poll_read(
51+
self: Pin<&mut Self>,
52+
cx: &mut Context<'_>,
53+
buf: &mut [u8],
54+
) -> Poll<io::Result<usize>> {
55+
let this = self.project();
56+
57+
if !*this.read {
58+
if let Ok(()) = this.sender.try_send(()) {
59+
*this.read = true;
60+
};
61+
}
62+
63+
this.reader.poll_read(cx, buf)
64+
}
65+
}

src/server/decode.rs

Lines changed: 21 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,23 @@
33
use std::str::FromStr;
44

55
use async_std::io::{BufReader, Read, Write};
6-
use async_std::prelude::*;
6+
use async_std::{prelude::*, sync, task};
77
use http_types::headers::{CONTENT_LENGTH, EXPECT, TRANSFER_ENCODING};
88
use http_types::{ensure, ensure_eq, format_err};
99
use http_types::{Body, Method, Request, Url};
1010

1111
use crate::chunked::ChunkedDecoder;
12+
use crate::read_notifier::ReadNotifier;
1213
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
1314

1415
const LF: u8 = b'\n';
1516

1617
/// The number returned from httparse when the request is HTTP 1.1
1718
const HTTP_1_1_VERSION: u8 = 1;
1819

20+
const CONTINUE_HEADER_VALUE: &str = "100-continue";
21+
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
22+
1923
/// Decode an HTTP request on the server.
2024
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<Request>>
2125
where
@@ -76,8 +80,6 @@ where
7680
req.insert_header(header.name, std::str::from_utf8(header.value)?);
7781
}
7882

79-
handle_100_continue(&req, &mut io).await?;
80-
8183
let content_length = req.header(CONTENT_LENGTH);
8284
let transfer_encoding = req.header(TRANSFER_ENCODING);
8385

@@ -86,11 +88,24 @@ where
8688
"Unexpected Content-Length header"
8789
);
8890

91+
let (sender, receiver) = sync::channel(1);
92+
93+
if let Some(CONTINUE_HEADER_VALUE) = req.header(EXPECT).map(|h| h.as_str()) {
94+
task::spawn(async move {
95+
if let Ok(()) = receiver.recv().await {
96+
io.write_all(CONTINUE_RESPONSE).await.ok();
97+
};
98+
});
99+
}
100+
89101
// Check for Transfer-Encoding
90102
if let Some(encoding) = transfer_encoding {
91103
if encoding.last().as_str() == "chunked" {
92104
let trailer_sender = req.send_trailers();
93-
let reader = BufReader::new(ChunkedDecoder::new(reader, trailer_sender));
105+
let reader = ReadNotifier::new(
106+
BufReader::new(ChunkedDecoder::new(reader, trailer_sender)),
107+
sender,
108+
);
94109
req.set_body(Body::from_reader(reader, None));
95110
return Ok(Some(req));
96111
}
@@ -100,7 +115,8 @@ where
100115
// Check for Content-Length.
101116
if let Some(len) = content_length {
102117
let len = len.last().as_str().parse::<usize>()?;
103-
req.set_body(Body::from_reader(reader.take(len as u64), Some(len)));
118+
let reader = ReadNotifier::new(reader.take(len as u64), sender);
119+
req.set_body(Body::from_reader(reader, Some(len)));
104120
}
105121

106122
Ok(Some(req))
@@ -129,20 +145,6 @@ fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<
129145
}
130146
}
131147

132-
const EXPECT_HEADER_VALUE: &str = "100-continue";
133-
const EXPECT_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
134-
135-
async fn handle_100_continue<IO>(req: &Request, io: &mut IO) -> http_types::Result<()>
136-
where
137-
IO: Write + Unpin,
138-
{
139-
if let Some(EXPECT_HEADER_VALUE) = req.header(EXPECT).map(|h| h.as_str()) {
140-
io.write_all(EXPECT_RESPONSE).await?;
141-
}
142-
143-
Ok(())
144-
}
145-
146148
#[cfg(test)]
147149
mod tests {
148150
use super::*;
@@ -207,36 +209,4 @@ mod tests {
207209
},
208210
)
209211
}
210-
211-
#[test]
212-
fn handle_100_continue_does_nothing_with_no_expect_header() {
213-
let request = Request::new(Method::Get, Url::parse("x:").unwrap());
214-
let mut io = async_std::io::Cursor::new(vec![]);
215-
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
216-
assert_eq!(std::str::from_utf8(&io.into_inner()).unwrap(), "");
217-
assert!(result.is_ok());
218-
}
219-
220-
#[test]
221-
fn handle_100_continue_sends_header_if_expects_is_exactly_right() {
222-
let mut request = Request::new(Method::Get, Url::parse("x:").unwrap());
223-
request.append_header("expect", "100-continue");
224-
let mut io = async_std::io::Cursor::new(vec![]);
225-
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
226-
assert_eq!(
227-
std::str::from_utf8(&io.into_inner()).unwrap(),
228-
"HTTP/1.1 100 Continue\r\n\r\n"
229-
);
230-
assert!(result.is_ok());
231-
}
232-
233-
#[test]
234-
fn handle_100_continue_does_nothing_if_expects_header_is_wrong() {
235-
let mut request = Request::new(Method::Get, Url::parse("x:").unwrap());
236-
request.append_header("expect", "110-extensions-not-allowed");
237-
let mut io = async_std::io::Cursor::new(vec![]);
238-
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
239-
assert_eq!(std::str::from_utf8(&io.into_inner()).unwrap(), "");
240-
assert!(result.is_ok());
241-
}
242212
}

0 commit comments

Comments
 (0)