3
3
use std:: str:: FromStr ;
4
4
5
5
use async_std:: io:: { BufReader , Read , Write } ;
6
- use async_std:: prelude:: * ;
6
+ use async_std:: { prelude:: * , sync , task } ;
7
7
use http_types:: headers:: { CONTENT_LENGTH , EXPECT , TRANSFER_ENCODING } ;
8
8
use http_types:: { ensure, ensure_eq, format_err} ;
9
9
use http_types:: { Body , Method , Request , Url } ;
10
10
11
11
use crate :: chunked:: ChunkedDecoder ;
12
+ use crate :: read_notifier:: ReadNotifier ;
12
13
use crate :: { MAX_HEADERS , MAX_HEAD_LENGTH } ;
13
14
14
15
const LF : u8 = b'\n' ;
15
16
16
17
/// The number returned from httparse when the request is HTTP 1.1
17
18
const HTTP_1_1_VERSION : u8 = 1 ;
18
19
20
+ const CONTINUE_HEADER_VALUE : & str = "100-continue" ;
21
+ const CONTINUE_RESPONSE : & [ u8 ] = b"HTTP/1.1 100 Continue\r \n \r \n " ;
22
+
19
23
/// Decode an HTTP request on the server.
20
24
pub async fn decode < IO > ( mut io : IO ) -> http_types:: Result < Option < Request > >
21
25
where
76
80
req. insert_header ( header. name , std:: str:: from_utf8 ( header. value ) ?) ;
77
81
}
78
82
79
- handle_100_continue ( & req, & mut io) . await ?;
80
-
81
83
let content_length = req. header ( CONTENT_LENGTH ) ;
82
84
let transfer_encoding = req. header ( TRANSFER_ENCODING ) ;
83
85
@@ -86,11 +88,24 @@ where
86
88
"Unexpected Content-Length header"
87
89
) ;
88
90
91
+ let ( sender, receiver) = sync:: channel ( 1 ) ;
92
+
93
+ if let Some ( CONTINUE_HEADER_VALUE ) = req. header ( EXPECT ) . map ( |h| h. as_str ( ) ) {
94
+ task:: spawn ( async move {
95
+ if let Ok ( ( ) ) = receiver. recv ( ) . await {
96
+ io. write_all ( CONTINUE_RESPONSE ) . await . ok ( ) ;
97
+ } ;
98
+ } ) ;
99
+ }
100
+
89
101
// Check for Transfer-Encoding
90
102
if let Some ( encoding) = transfer_encoding {
91
103
if encoding. last ( ) . as_str ( ) == "chunked" {
92
104
let trailer_sender = req. send_trailers ( ) ;
93
- let reader = BufReader :: new ( ChunkedDecoder :: new ( reader, trailer_sender) ) ;
105
+ let reader = ReadNotifier :: new (
106
+ BufReader :: new ( ChunkedDecoder :: new ( reader, trailer_sender) ) ,
107
+ sender,
108
+ ) ;
94
109
req. set_body ( Body :: from_reader ( reader, None ) ) ;
95
110
return Ok ( Some ( req) ) ;
96
111
}
@@ -100,7 +115,8 @@ where
100
115
// Check for Content-Length.
101
116
if let Some ( len) = content_length {
102
117
let len = len. last ( ) . as_str ( ) . parse :: < usize > ( ) ?;
103
- req. set_body ( Body :: from_reader ( reader. take ( len as u64 ) , Some ( len) ) ) ;
118
+ let reader = ReadNotifier :: new ( reader. take ( len as u64 ) , sender) ;
119
+ req. set_body ( Body :: from_reader ( reader, Some ( len) ) ) ;
104
120
}
105
121
106
122
Ok ( Some ( req) )
@@ -129,20 +145,6 @@ fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<
129
145
}
130
146
}
131
147
132
- const EXPECT_HEADER_VALUE : & str = "100-continue" ;
133
- const EXPECT_RESPONSE : & [ u8 ] = b"HTTP/1.1 100 Continue\r \n \r \n " ;
134
-
135
- async fn handle_100_continue < IO > ( req : & Request , io : & mut IO ) -> http_types:: Result < ( ) >
136
- where
137
- IO : Write + Unpin ,
138
- {
139
- if let Some ( EXPECT_HEADER_VALUE ) = req. header ( EXPECT ) . map ( |h| h. as_str ( ) ) {
140
- io. write_all ( EXPECT_RESPONSE ) . await ?;
141
- }
142
-
143
- Ok ( ( ) )
144
- }
145
-
146
148
#[ cfg( test) ]
147
149
mod tests {
148
150
use super :: * ;
@@ -207,36 +209,4 @@ mod tests {
207
209
} ,
208
210
)
209
211
}
210
-
211
- #[ test]
212
- fn handle_100_continue_does_nothing_with_no_expect_header ( ) {
213
- let request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
214
- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
215
- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
216
- assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
217
- assert ! ( result. is_ok( ) ) ;
218
- }
219
-
220
- #[ test]
221
- fn handle_100_continue_sends_header_if_expects_is_exactly_right ( ) {
222
- let mut request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
223
- request. append_header ( "expect" , "100-continue" ) ;
224
- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
225
- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
226
- assert_eq ! (
227
- std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) ,
228
- "HTTP/1.1 100 Continue\r \n \r \n "
229
- ) ;
230
- assert ! ( result. is_ok( ) ) ;
231
- }
232
-
233
- #[ test]
234
- fn handle_100_continue_does_nothing_if_expects_header_is_wrong ( ) {
235
- let mut request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
236
- request. append_header ( "expect" , "110-extensions-not-allowed" ) ;
237
- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
238
- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
239
- assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
240
- assert ! ( result. is_ok( ) ) ;
241
- }
242
212
}
0 commit comments