Skip to content

don't send 100-continue until the body has been read from #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ log = "0.4"

[dev-dependencies]
pretty_assertions = "0.6.1"
async-std = { version = "1.4.0", features = ["unstable", "attributes"] }
async-std = { version = "1.6.2", features = ["unstable", "attributes"] }
tempfile = "3.1.0"
async-test = "1.0.0"
duplexify = "1.2.1"
async-dup = "1.2.1"
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ const MAX_HEAD_LENGTH: usize = 8 * 1024;

mod chunked;
mod date;
mod read_notifier;

pub mod client;
pub mod server;
Expand Down
66 changes: 66 additions & 0 deletions src/read_notifier.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};

use async_std::io::{self, BufRead, Read};
use async_std::sync::Sender;

pin_project_lite::pin_project! {
/// ReadNotifier forwards [`async_std::io::Read`] and
/// [`async_std::io::BufRead`] to an inner reader. When the
/// ReadNotifier is read from (using `Read`, `ReadExt`, or
/// `BufRead` methods), it sends a single message containing `()`
/// on the channel.
pub(crate) struct ReadNotifier<B> {
#[pin]
reader: B,
sender: Sender<()>,
has_been_read: bool
}
}

impl<B> fmt::Debug for ReadNotifier<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ReadNotifier")
.field("read", &self.has_been_read)
.finish()
}
}

impl<B: BufRead> ReadNotifier<B> {
pub(crate) fn new(reader: B, sender: Sender<()>) -> Self {
Self {
reader,
sender,
has_been_read: false,
}
}
}

impl<B: BufRead> BufRead for ReadNotifier<B> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
self.project().reader.poll_fill_buf(cx)
}

fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().reader.consume(amt)
}
}

impl<B: Read> Read for ReadNotifier<B> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.project();

if !*this.has_been_read {
if let Ok(()) = this.sender.try_send(()) {
*this.has_been_read = true;
};
}

this.reader.poll_read(cx, buf)
}
}
80 changes: 29 additions & 51 deletions src/server/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@
use std::str::FromStr;

use async_std::io::{BufReader, Read, Write};
use async_std::prelude::*;
use async_std::{prelude::*, sync, task};
use http_types::headers::{CONTENT_LENGTH, EXPECT, TRANSFER_ENCODING};
use http_types::{ensure, ensure_eq, format_err};
use http_types::{Body, Method, Request, Url};

use crate::chunked::ChunkedDecoder;
use crate::read_notifier::ReadNotifier;
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};

const LF: u8 = b'\n';

/// The number returned from httparse when the request is HTTP 1.1
const HTTP_1_1_VERSION: u8 = 1;

const CONTINUE_HEADER_VALUE: &str = "100-continue";
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";

/// Decode an HTTP request on the server.
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<Request>>
where
Expand Down Expand Up @@ -76,8 +80,6 @@ where
req.insert_header(header.name, std::str::from_utf8(header.value)?);
}

handle_100_continue(&req, &mut io).await?;

let content_length = req.header(CONTENT_LENGTH);
let transfer_encoding = req.header(TRANSFER_ENCODING);

Expand All @@ -86,11 +88,32 @@ where
"Unexpected Content-Length header"
);

// Establish a channel to wait for the body to be read. This
// allows us to avoid sending 100-continue in situations that
// respond without reading the body, saving clients from uploading
// their body.
let (body_read_sender, body_read_receiver) = sync::channel(1);

if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) {
task::spawn(async move {
// If the client expects a 100-continue header, spawn a
// task to wait for the first read attempt on the body.
if let Ok(()) = body_read_receiver.recv().await {
io.write_all(CONTINUE_RESPONSE).await.ok();
};
// Since the sender is moved into the Body, this task will
// finish when the client disconnects, whether or not
// 100-continue was sent.
});
}

// Check for Transfer-Encoding
if let Some(encoding) = transfer_encoding {
if encoding.last().as_str() == "chunked" {
let trailer_sender = req.send_trailers();
let reader = BufReader::new(ChunkedDecoder::new(reader, trailer_sender));
let reader = ChunkedDecoder::new(reader, trailer_sender);
let reader = BufReader::new(reader);
let reader = ReadNotifier::new(reader, body_read_sender);
req.set_body(Body::from_reader(reader, None));
return Ok(Some(req));
}
Expand All @@ -100,7 +123,8 @@ where
// Check for Content-Length.
if let Some(len) = content_length {
let len = len.last().as_str().parse::<usize>()?;
req.set_body(Body::from_reader(reader.take(len as u64), Some(len)));
let reader = ReadNotifier::new(reader.take(len as u64), body_read_sender);
req.set_body(Body::from_reader(reader, Some(len)));
}

Ok(Some(req))
Expand Down Expand Up @@ -129,20 +153,6 @@ fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<
}
}

const EXPECT_HEADER_VALUE: &str = "100-continue";
const EXPECT_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";

async fn handle_100_continue<IO>(req: &Request, io: &mut IO) -> http_types::Result<()>
where
IO: Write + Unpin,
{
if let Some(EXPECT_HEADER_VALUE) = req.header(EXPECT).map(|h| h.as_str()) {
io.write_all(EXPECT_RESPONSE).await?;
}

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -207,36 +217,4 @@ mod tests {
},
)
}

#[test]
fn handle_100_continue_does_nothing_with_no_expect_header() {
let request = Request::new(Method::Get, Url::parse("x:").unwrap());
let mut io = async_std::io::Cursor::new(vec![]);
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
assert_eq!(std::str::from_utf8(&io.into_inner()).unwrap(), "");
assert!(result.is_ok());
}

#[test]
fn handle_100_continue_sends_header_if_expects_is_exactly_right() {
let mut request = Request::new(Method::Get, Url::parse("x:").unwrap());
request.append_header("expect", "100-continue");
let mut io = async_std::io::Cursor::new(vec![]);
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
assert_eq!(
std::str::from_utf8(&io.into_inner()).unwrap(),
"HTTP/1.1 100 Continue\r\n\r\n"
);
assert!(result.is_ok());
}

#[test]
fn handle_100_continue_does_nothing_if_expects_header_is_wrong() {
let mut request = Request::new(Method::Get, Url::parse("x:").unwrap());
request.append_header("expect", "110-extensions-not-allowed");
let mut io = async_std::io::Cursor::new(vec![]);
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
assert_eq!(std::str::from_utf8(&io.into_inner()).unwrap(), "");
assert!(result.is_ok());
}
}
75 changes: 75 additions & 0 deletions tests/continue.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use async_dup::{Arc, Mutex};
use async_std::io::{Cursor, SeekFrom};
use async_std::{prelude::*, task};
use duplexify::Duplex;
use http_types::Result;
use std::time::Duration;

const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\
Host: example.com\r\n\
Content-Length: 10\r\n\
Expect: 100-continue\r\n\r\n";

const SLEEP_DURATION: Duration = std::time::Duration::from_millis(100);
#[async_std::test]
async fn test_with_expect_when_reading_body() -> Result<()> {
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
let server_str: Vec<u8> = vec![];

let mut client = Arc::new(Mutex::new(Cursor::new(client_str)));
let server = Arc::new(Mutex::new(Cursor::new(server_str)));

let mut request = async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
.await?
.unwrap();

task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written

{
let lock = server.lock();
assert_eq!("", std::str::from_utf8(lock.get_ref())?); //we haven't written yet
};

let mut buf = vec![0u8; 1];
let bytes = request.read(&mut buf).await?; //this triggers the 100-continue even though there's nothing to read yet
assert_eq!(bytes, 0); // normally we'd actually be waiting for the end of the buffer, but this lets us test this sequentially

task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel and io

{
let lock = server.lock();
assert_eq!(
"HTTP/1.1 100 Continue\r\n\r\n",
std::str::from_utf8(lock.get_ref())?
);
};

client.write_all(b"0123456789").await?;
client
.seek(SeekFrom::Start(REQUEST_WITH_EXPECT.len() as u64))
.await?;

assert_eq!("0123456789", request.body_string().await?);

Ok(())
}

#[async_std::test]
async fn test_without_expect_when_not_reading_body() -> Result<()> {
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
let server_str: Vec<u8> = vec![];

let client = Arc::new(Mutex::new(Cursor::new(client_str)));
let server = Arc::new(Mutex::new(Cursor::new(server_str)));

async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
.await?
.unwrap();

task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel

let server_lock = server.lock();
assert_eq!("", std::str::from_utf8(server_lock.get_ref())?); // we haven't written 100-continue

Ok(())
}