Skip to content

Extract hyper-independent part of rust-websocket into a separate crate. #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "websocket"
version = "0.23.0"
version = "0.24.0"
authors = ["cyderize <[email protected]>", "Michael Eden <[email protected]>"]

description = "A WebSocket (RFC6455) library for Rust."
Expand All @@ -19,11 +19,7 @@ license = "MIT"
hyper = "^0.10.6"
unicase = "1.0"
url = "1.0"
bitflags = "1.0.4"
rand = "0.6.1"
byteorder = "1.0"
sha1 = "0.6"
base64 = "0.10.0"
futures = { version = "0.1", optional = true }
tokio-io = { version = "0.1", optional = true }
tokio-tls = { version = "0.2.0", optional = true }
Expand All @@ -32,6 +28,7 @@ tokio-codec = { version = "0.1", optional = true }
tokio-reactor = { version = "0.1", optional = true }
bytes = { version = "0.4", optional = true }
native-tls = { version = "0.2.1", optional = true }
websocket-base = { path = "websocket-base", version="0.24.0", default-features=false }

[dev-dependencies]
futures-cpupool = "0.1"
Expand All @@ -43,8 +40,13 @@ features = ["codec", "tcp", "rt-full"]

[features]
default = ["sync", "sync-ssl", "async", "async-ssl"]
sync = []
sync-ssl = ["native-tls", "sync"]
async = ["bytes", "futures", "tokio-io", "tokio-tcp", "tokio-reactor", "tokio-codec"]
async-ssl = ["native-tls", "tokio-tls", "async"]
sync = ["websocket-base/sync"]
sync-ssl = ["native-tls", "sync", "websocket-base/sync-ssl"]
async = ["bytes", "futures", "tokio-io", "tokio-tcp", "tokio-reactor", "tokio-codec", "websocket-base/async"]
async-ssl = ["native-tls", "tokio-tls", "async", "websocket-base/async-ssl"]
nightly = ["hyper/nightly"]

[workspace]
members = [
"websocket-base"
]
50 changes: 29 additions & 21 deletions src/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mod common_imports {
pub use hyper::method::Method;
pub use hyper::status::StatusCode;
pub use hyper::uri::RequestUri;
pub use result::{WSUrlErrorKind, WebSocketError, WebSocketResult};
pub use result::{WSUrlErrorKind, WebSocketError, WebSocketOtherError, WebSocketResult};
pub use std::net::TcpStream;
pub use std::net::ToSocketAddrs;
pub use stream::{self, Stream};
Expand Down Expand Up @@ -58,6 +58,8 @@ mod async_imports {
#[cfg(feature = "async")]
use self::async_imports::*;

use result::towse;

/// Build clients with a builder-style API
/// This makes it easy to create and configure a websocket
/// connection:
Expand Down Expand Up @@ -289,7 +291,7 @@ impl<'u> ClientBuilder<'u> {
/// Use this only if you know what you're doing, and this almost
/// never has to be used.
pub fn key(mut self, key: [u8; 16]) -> Self {
self.headers.set(WebSocketKey(key));
self.headers.set(WebSocketKey::from_array(key));
self.key_set = true;
self
}
Expand Down Expand Up @@ -484,7 +486,7 @@ impl<'u> ClientBuilder<'u> {

// wait for a response
let mut reader = BufReader::new(stream);
let response = parse_response(&mut reader)?;
let response = parse_response(&mut reader).map_err(towse)?;

// validate
self.validate(&response)?;
Expand Down Expand Up @@ -563,7 +565,7 @@ impl<'u> ClientBuilder<'u> {
};
// secure connection, wrap with ssl
let future = tcp_stream
.and_then(move |s| connector.connect(&host, s).map_err(Into::into))
.and_then(move |s| connector.connect(&host, s).map_err(towse))
.and_then(move |stream| {
let stream: Box<stream::async::Stream + Send> = Box::new(stream);
builder.async_connect_on(stream)
Expand Down Expand Up @@ -639,7 +641,7 @@ impl<'u> ClientBuilder<'u> {

// put it all together
let future = tcp_stream
.and_then(move |s| connector.connect(&host, s).map_err(Into::into))
.and_then(move |s| connector.connect(&host, s).map_err(towse))
.and_then(move |stream| builder.async_connect_on(stream));
Box::new(future)
}
Expand Down Expand Up @@ -760,7 +762,7 @@ impl<'u> ClientBuilder<'u> {
.send(request)
.map_err(::std::convert::Into::into)
// wait for a response
.and_then(|stream| stream.into_future().map_err(|e| e.0.into()))
.and_then(|stream| stream.into_future().map_err(|e| towse(e.0)))
// validate
.and_then(move |(message, stream)| {
message
Expand Down Expand Up @@ -793,9 +795,10 @@ impl<'u> ClientBuilder<'u> {
Some(a) => a,
None => {
return Box::new(
Err(WebSocketError::WebSocketUrlError(
Err(WebSocketOtherError::WebSocketUrlError(
WSUrlErrorKind::NoHostName,
))
.map_err(towse)
.into_future(),
);
}
Expand Down Expand Up @@ -855,20 +858,21 @@ impl<'u> ClientBuilder<'u> {
let status = StatusCode::from_u16(response.subject.0);

if status != StatusCode::SwitchingProtocols {
return Err(WebSocketError::StatusCodeError(status));
return Err(WebSocketOtherError::StatusCodeError(status)).map_err(towse);
}

let key = self
.headers
.get::<WebSocketKey>()
.ok_or(WebSocketError::RequestError(
.ok_or(WebSocketOtherError::RequestError(
"Request Sec-WebSocket-Key was invalid",
))?;

if response.headers.get() != Some(&(WebSocketAccept::new(key))) {
return Err(WebSocketError::ResponseError(
return Err(WebSocketOtherError::ResponseError(
"Sec-WebSocket-Accept is invalid",
));
))
.map_err(towse);
}

if response.headers.get()
Expand All @@ -878,9 +882,10 @@ impl<'u> ClientBuilder<'u> {
version: None,
}])),
) {
return Err(WebSocketError::ResponseError(
return Err(WebSocketOtherError::ResponseError(
"Upgrade field must be WebSocket",
));
))
.map_err(towse);
}

if self.headers.get()
Expand All @@ -889,9 +894,10 @@ impl<'u> ClientBuilder<'u> {
"Upgrade".to_string(),
))])),
) {
return Err(WebSocketError::ResponseError(
return Err(WebSocketOtherError::ResponseError(
"Connection field must be 'Upgrade'",
));
))
.map_err(towse);
}

Ok(())
Expand All @@ -909,9 +915,10 @@ impl<'u> ClientBuilder<'u> {
#[cfg(any(feature = "sync", feature = "async"))]
fn extract_host_port(&self, secure: Option<bool>) -> WebSocketResult<::url::HostAndPort<&str>> {
if self.url.host().is_none() {
return Err(WebSocketError::WebSocketUrlError(
return Err(WebSocketOtherError::WebSocketUrlError(
WSUrlErrorKind::NoHostName,
));
))
.map_err(towse);
}

Ok(self.url.with_default_port(|url| {
Expand Down Expand Up @@ -941,14 +948,15 @@ impl<'u> ClientBuilder<'u> {
let host = match self.url.host_str() {
Some(h) => h,
None => {
return Err(WebSocketError::WebSocketUrlError(
return Err(WebSocketOtherError::WebSocketUrlError(
WSUrlErrorKind::NoHostName,
));
))
.map_err(towse);
}
};
let connector = match connector {
Some(c) => c,
None => TlsConnector::builder().build()?,
None => TlsConnector::builder().build().map_err(towse)?,
};
Ok((host, connector))
}
Expand All @@ -960,7 +968,7 @@ impl<'u> ClientBuilder<'u> {
connector: Option<TlsConnector>,
) -> WebSocketResult<TlsStream<TcpStream>> {
let (host, connector) = self.extract_host_ssl_conn(connector)?;
let ssl_stream = connector.connect(host, tcp_stream)?;
let ssl_stream = connector.connect(host, tcp_stream).map_err(towse)?;
Ok(ssl_stream)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
//! See it's module level documentation for more info.

pub mod http;
pub mod ws;
pub use websocket_base::codec::ws;
36 changes: 6 additions & 30 deletions src/header/accept.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,39 @@
use base64;
use header::WebSocketKey;
use hyper;
use hyper::header::parsing::from_one_raw_str;
use hyper::header::{Header, HeaderFormat};
use result::{WebSocketError, WebSocketResult};
use sha1::Sha1;
use std::fmt::{self, Debug};
use std::str::FromStr;

static MAGIC_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
use websocket_base::header::WebSocketAccept as WebSocketAcceptLL;

/// Represents a Sec-WebSocket-Accept header
#[derive(PartialEq, Clone, Copy)]
pub struct WebSocketAccept([u8; 20]);
pub struct WebSocketAccept(WebSocketAcceptLL);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? pub use websocket_lowlevel::header::WebSocketAccept would be enough

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not implement hyper_0.10's Header and HeaderFormat traits.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see


impl Debug for WebSocketAccept {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "WebSocketAccept({})", self.serialize())
self.0.fmt(f)
}
}

impl FromStr for WebSocketAccept {
type Err = WebSocketError;

fn from_str(accept: &str) -> WebSocketResult<WebSocketAccept> {
match base64::decode(accept) {
Ok(vec) => {
if vec.len() != 20 {
return Err(WebSocketError::ProtocolError(
"Sec-WebSocket-Accept must be 20 bytes",
));
}
let mut array = [0u8; 20];
array[..20].clone_from_slice(&vec[..20]);
Ok(WebSocketAccept(array))
}
Err(_) => Err(WebSocketError::ProtocolError(
"Invalid Sec-WebSocket-Accept",
)),
}
Ok(WebSocketAccept(WebSocketAcceptLL::from_str(accept)?))
}
}

impl WebSocketAccept {
/// Create a new WebSocketAccept from the given WebSocketKey
pub fn new(key: &WebSocketKey) -> WebSocketAccept {
let serialized = key.serialize();
let mut concat_key = String::with_capacity(serialized.len() + 36);
concat_key.push_str(&serialized[..]);
concat_key.push_str(MAGIC_GUID);
let mut sha1 = Sha1::new();
sha1.update(concat_key.as_bytes());
let bytes = sha1.digest().bytes();
WebSocketAccept(bytes)
WebSocketAccept(WebSocketAcceptLL::new(&key.0))
}
/// Return the Base64 encoding of this WebSocketAccept
pub fn serialize(&self) -> String {
let WebSocketAccept(accept) = *self;
base64::encode(&accept)
self.0.serialize()
}
}

Expand Down
37 changes: 13 additions & 24 deletions src/header/key.rs
Original file line number Diff line number Diff line change
@@ -1,54 +1,43 @@
use base64;
use hyper;
use hyper::header::parsing::from_one_raw_str;
use hyper::header::{Header, HeaderFormat};
use rand;
use result::{WebSocketError, WebSocketResult};
use std::fmt::{self, Debug};
use std::str::FromStr;

use websocket_base::header::WebSocketKey as WebSocketKeyLL;

/// Represents a Sec-WebSocket-Key header.
#[derive(PartialEq, Clone, Copy, Default)]
pub struct WebSocketKey(pub [u8; 16]);
pub struct WebSocketKey(pub WebSocketKeyLL);

impl Debug for WebSocketKey {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "WebSocketKey({})", self.serialize())
self.0.fmt(f)
}
}

impl FromStr for WebSocketKey {
type Err = WebSocketError;

fn from_str(key: &str) -> WebSocketResult<WebSocketKey> {
match base64::decode(key) {
Ok(vec) => {
if vec.len() != 16 {
return Err(WebSocketError::ProtocolError(
"Sec-WebSocket-Key must be 16 bytes",
));
}
let mut array = [0u8; 16];
array[..16].clone_from_slice(&vec[..16]);
Ok(WebSocketKey(array))
}
Err(_) => Err(WebSocketError::ProtocolError(
"Invalid Sec-WebSocket-Accept",
)),
}
Ok(WebSocketKey(WebSocketKeyLL::from_str(key)?))
}
}

impl WebSocketKey {
/// Generate a new, random WebSocketKey
pub fn new() -> WebSocketKey {
let key = rand::random();
WebSocketKey(key)
WebSocketKey(WebSocketKeyLL::new())
}
/// Return the Base64 encoding of this WebSocketKey
pub fn serialize(&self) -> String {
let WebSocketKey(key) = *self;
base64::encode(&key)
self.0.serialize()
}

/// Create WebSocketKey by explicitly specifying the key
pub fn from_array(a: [u8; 16]) -> WebSocketKey {
WebSocketKey(WebSocketKeyLL(a))
}
}

Expand Down Expand Up @@ -78,7 +67,7 @@ mod tests {
fn test_header_key() {
use header::Headers;

let extensions = WebSocketKey([65; 16]);
let extensions = WebSocketKey::from_array([65; 16]);
let mut headers = Headers::new();
headers.set(extensions);

Expand Down
Loading