Skip to content

Commit c65bdbc

Browse files
committed
Allow adding headers to switching protocols response
Signed-off-by: Jens Reidel <[email protected]>
1 parent baece0c commit c65bdbc

File tree

2 files changed

+61
-10
lines changed

2 files changed

+61
-10
lines changed

src/server.rs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
use std::{future::poll_fn, io, pin::Pin};
99

1010
use futures_core::Stream;
11+
use http::{HeaderMap, HeaderName, HeaderValue, header};
1112
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
1213
use tokio_util::codec::FramedRead;
1314

@@ -20,12 +21,27 @@ use crate::{
2021
/// HTTP/1.1 400 Bad Request response payload.
2122
const BAD_REQUEST: &[u8] = b"HTTP/1.1 400 Bad Request\r\n\r\n";
2223

24+
/// List of headers added by the server which will cause an error
25+
/// if added by the user:
26+
///
27+
/// - `host`
28+
/// - `upgrade`
29+
/// - `connection`
30+
/// - `sec-websocket-accept`
31+
pub const DISALLOWED_HEADERS: &[HeaderName] = &[
32+
header::UPGRADE,
33+
header::CONNECTION,
34+
header::SEC_WEBSOCKET_ACCEPT,
35+
];
36+
2337
/// Builder for WebSocket server connections.
2438
pub struct Builder {
2539
/// Configuration for the WebSocket stream.
2640
config: Config,
2741
/// Limits to impose on the WebSocket stream.
2842
limits: Limits,
43+
/// Headers to be sent with the switching protocols response.
44+
headers: HeaderMap,
2945
}
3046

3147
impl Default for Builder {
@@ -42,6 +58,7 @@ impl Builder {
4258
Self {
4359
config: Config::default(),
4460
limits: Limits::default(),
61+
headers: HeaderMap::new(),
4562
}
4663
}
4764

@@ -61,6 +78,21 @@ impl Builder {
6178
self
6279
}
6380

81+
/// Adds an extra HTTP header to the switching protocols response.
82+
///
83+
/// # Errors
84+
///
85+
/// Returns [`Error::DisallowedHeader`] if the header is in
86+
/// the [`DISALLOWED_HEADERS`] list.
87+
pub fn add_header(mut self, name: HeaderName, value: HeaderValue) -> Result<Self, Error> {
88+
if DISALLOWED_HEADERS.contains(&name) {
89+
return Err(Error::DisallowedHeader);
90+
}
91+
self.headers.insert(name, value);
92+
93+
Ok(self)
94+
}
95+
6496
/// Perform a HTTP upgrade handshake on an already established stream and
6597
/// uses it to send and receive WebSocket messages.
6698
///
@@ -71,12 +103,17 @@ impl Builder {
71103
&self,
72104
stream: S,
73105
) -> Result<(http::Request<()>, WebSocketStream<S>), Error> {
74-
let mut framed = FramedRead::new(stream, client_request::Codec {});
106+
let mut framed = FramedRead::new(
107+
stream,
108+
client_request::Codec {
109+
response_headers: &self.headers,
110+
},
111+
);
75112
let reply = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx)).await;
76113

77114
match reply {
78115
Some(Ok((request, response))) => {
79-
framed.get_mut().write_all(response.as_bytes()).await?;
116+
framed.get_mut().write_all(&response).await?;
80117
Ok((
81118
request,
82119
WebSocketStream::from_framed(framed, Role::Server, self.config, self.limits),

src/upgrade/client_request.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ use std::str::FromStr;
33

44
use base64::{Engine, engine::general_purpose::STANDARD};
55
use bytes::{Buf, BytesMut};
6+
use http::HeaderMap;
67
use httparse::Request;
78
use tokio_util::codec::Decoder;
89

910
use crate::{sha::digest, upgrade::Error};
1011

1112
/// A static HTTP/1.1 101 Switching Protocols response up until the
1213
/// `Sec-WebSocket-Accept` header value.
13-
const SWITCHING_PROTOCOLS_BODY: &str = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
14+
const SWITCHING_PROTOCOLS_BODY: &[u8] = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
1415

1516
/// Returns whether an ASCII byte slice is contained in another one, ignoring
1617
/// captalization.
@@ -94,11 +95,14 @@ impl ClientRequest {
9495
/// It does not implement an [`Encoder`].
9596
///
9697
/// [`Encoder`]: tokio_util::codec::Encoder
97-
pub struct Codec {}
98+
pub struct Codec<'a> {
99+
/// List of headers to add to the Switching Protocols response.
100+
pub response_headers: &'a HeaderMap,
101+
}
98102

99-
impl Decoder for Codec {
103+
impl Decoder for Codec<'_> {
100104
type Error = crate::Error;
101-
type Item = (http::Request<()>, String);
105+
type Item = (http::Request<()>, Vec<u8>);
102106

103107
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
104108
let mut headers = [httparse::EMPTY_HEADER; 64];
@@ -151,11 +155,21 @@ impl Decoder for Codec {
151155

152156
src.advance(request_len);
153157

154-
let mut resp = String::with_capacity(SWITCHING_PROTOCOLS_BODY.len() + ws_accept.len() + 4);
158+
// Preallocate the size without extra headers
159+
let mut resp = Vec::with_capacity(SWITCHING_PROTOCOLS_BODY.len() + ws_accept.len() + 4);
160+
161+
resp.extend_from_slice(SWITCHING_PROTOCOLS_BODY);
162+
resp.extend_from_slice(ws_accept.as_bytes());
163+
resp.extend_from_slice(b"\r\n");
164+
165+
for (name, value) in self.response_headers {
166+
resp.extend_from_slice(name.as_str().as_bytes());
167+
resp.extend_from_slice(b": ");
168+
resp.extend_from_slice(value.as_bytes());
169+
resp.extend_from_slice(b"\r\n");
170+
}
155171

156-
resp.push_str(SWITCHING_PROTOCOLS_BODY);
157-
resp.push_str(&ws_accept);
158-
resp.push_str("\r\n\r\n");
172+
resp.extend_from_slice(b"\r\n");
159173

160174
Ok(Some((request, resp)))
161175
}

0 commit comments

Comments
 (0)