88use std:: { future:: poll_fn, io, pin:: Pin } ;
99
1010use futures_core:: Stream ;
11+ use http:: { HeaderMap , HeaderName , HeaderValue , header} ;
1112use tokio:: io:: { AsyncRead , AsyncWrite , AsyncWriteExt } ;
1213use tokio_util:: codec:: FramedRead ;
1314
@@ -20,12 +21,27 @@ use crate::{
2021/// HTTP/1.1 400 Bad Request response payload.
2122const BAD_REQUEST : & [ u8 ] = b"HTTP/1.1 400 Bad Request\r \n \r \n " ;
2223
24+ /// List of headers added by the server which will cause an error
25+ /// if added by the user:
26+ ///
27+ /// - `host`
28+ /// - `upgrade`
29+ /// - `connection`
30+ /// - `sec-websocket-accept`
31+ pub const DISALLOWED_HEADERS : & [ HeaderName ] = & [
32+ header:: UPGRADE ,
33+ header:: CONNECTION ,
34+ header:: SEC_WEBSOCKET_ACCEPT ,
35+ ] ;
36+
2337/// Builder for WebSocket server connections.
2438pub struct Builder {
2539 /// Configuration for the WebSocket stream.
2640 config : Config ,
2741 /// Limits to impose on the WebSocket stream.
2842 limits : Limits ,
43+ /// Headers to be sent with the switching protocols response.
44+ headers : HeaderMap ,
2945}
3046
3147impl Default for Builder {
@@ -42,6 +58,7 @@ impl Builder {
4258 Self {
4359 config : Config :: default ( ) ,
4460 limits : Limits :: default ( ) ,
61+ headers : HeaderMap :: new ( ) ,
4562 }
4663 }
4764
@@ -61,6 +78,21 @@ impl Builder {
6178 self
6279 }
6380
81+ /// Adds an extra HTTP header to the switching protocols response.
82+ ///
83+ /// # Errors
84+ ///
85+ /// Returns [`Error::DisallowedHeader`] if the header is in
86+ /// the [`DISALLOWED_HEADERS`] list.
87+ pub fn add_header ( mut self , name : HeaderName , value : HeaderValue ) -> Result < Self , Error > {
88+ if DISALLOWED_HEADERS . contains ( & name) {
89+ return Err ( Error :: DisallowedHeader ) ;
90+ }
91+ self . headers . insert ( name, value) ;
92+
93+ Ok ( self )
94+ }
95+
6496 /// Perform a HTTP upgrade handshake on an already established stream and
6597 /// uses it to send and receive WebSocket messages.
6698 ///
@@ -71,12 +103,17 @@ impl Builder {
71103 & self ,
72104 stream : S ,
73105 ) -> Result < ( http:: Request < ( ) > , WebSocketStream < S > ) , Error > {
74- let mut framed = FramedRead :: new ( stream, client_request:: Codec { } ) ;
106+ let mut framed = FramedRead :: new (
107+ stream,
108+ client_request:: Codec {
109+ response_headers : & self . headers ,
110+ } ,
111+ ) ;
75112 let reply = poll_fn ( |cx| Pin :: new ( & mut framed) . poll_next ( cx) ) . await ;
76113
77114 match reply {
78115 Some ( Ok ( ( request, response) ) ) => {
79- framed. get_mut ( ) . write_all ( response. as_bytes ( ) ) . await ?;
116+ framed. get_mut ( ) . write_all ( & response) . await ?;
80117 Ok ( (
81118 request,
82119 WebSocketStream :: from_framed ( framed, Role :: Server , self . config , self . limits ) ,
0 commit comments