Skip to content

Commit 201c2d9

Browse files
authored
Merge branch 'master' into use-alpn-by-default
2 parents 72baf15 + 8790ded commit 201c2d9

File tree

5 files changed

+85
-71
lines changed

5 files changed

+85
-71
lines changed

Cargo.toml

+6-6
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@ homepage = "https://github.com/ctz/hyper-rustls"
1010
repository = "https://github.com/ctz/hyper-rustls"
1111

1212
[dependencies]
13-
bytes = "0.4"
13+
bytes = "0.5.2"
1414
ct-logs = { version = "^0.6.0", optional = true }
15-
futures-util-preview = { version = "0.3.0-alpha.19" }
16-
hyper = { version = "0.13.0-alpha.4", default-features = false, features = ["unstable-stream"] }
15+
futures-util = "0.3.1"
16+
hyper = { version = "0.13.0", default-features = false, features = ["tcp"] }
1717
rustls = "0.16"
18-
tokio-io = { version="0.2.0-alpha.6" }
19-
tokio-rustls = "0.12.0-alpha.4"
18+
tokio = "0.2.4"
19+
tokio-rustls = "0.12.1"
2020
webpki = "^0.21.0"
2121
rustls-native-certs = { version = "^0.1.0", optional = true }
2222

2323
[dev-dependencies]
24-
tokio = "0.2.0-alpha.6"
24+
tokio = { version = "0.2.4", features = ["io-std", "macros", "dns", "stream"] }
2525

2626
[features]
2727
default = ["tokio-runtime"]

examples/client.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
//!
33
//! First parameter is the mandatory URL to GET.
44
//! Second parameter is an optional path to CA store.
5-
use futures_util::TryStreamExt;
6-
use hyper::{client, Body, Chunk, Uri};
5+
use hyper::{body::to_bytes, client, Body, Uri};
76
use std::str::FromStr;
87
use std::{env, fs, io};
98

@@ -74,8 +73,7 @@ async fn run_client() -> io::Result<()> {
7473
println!("Headers:\n{:#?}", res.headers());
7574

7675
let body: Body = res.into_body();
77-
let body: Chunk = body
78-
.try_concat()
76+
let body = to_bytes(body)
7977
.await
8078
.map_err(|e| error(format!("Could not get body: {:?}", e)))?;
8179
println!("Body:\n{}", String::from_utf8_lossy(&body));

examples/server.rs

+20-24
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,16 @@
66
//! otherwise HTTP/1.1 will be used.
77
use core::task::{Context, Poll};
88
use futures_util::{
9-
stream::{Stream, StreamExt},
10-
try_future::TryFutureExt,
11-
try_stream::TryStreamExt,
9+
future::TryFutureExt,
10+
stream::{Stream, StreamExt, TryStreamExt},
1211
};
1312
use hyper::service::{make_service_fn, service_fn};
1413
use hyper::{Body, Method, Request, Response, Server, StatusCode};
1514
use rustls::internal::pemfile;
1615
use std::pin::Pin;
17-
use std::{env, fs, io, sync};
1816
use std::vec::Vec;
19-
use tokio::net::tcp::{TcpListener, TcpStream};
17+
use std::{env, fs, io, sync};
18+
use tokio::net::{TcpListener, TcpStream};
2019
use tokio_rustls::server::TlsStream;
2120
use tokio_rustls::TlsAcceptor;
2221

@@ -53,29 +52,26 @@ async fn run_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
5352
cfg.set_single_cert(certs, key)
5453
.map_err(|e| error(format!("{}", e)))?;
5554
// Configure ALPN to accept HTTP/2, HTTP/1.1 in that order.
56-
cfg.set_protocols(&[
57-
b"h2".to_vec(),
58-
b"http/1.1".to_vec(),
59-
]);
55+
cfg.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
6056
sync::Arc::new(cfg)
6157
};
6258

6359
// Create a TCP listener via tokio.
64-
let tcp = TcpListener::bind(&addr).await?;
60+
let mut tcp = TcpListener::bind(&addr).await?;
6561
let tls_acceptor = TlsAcceptor::from(tls_cfg);
6662
// Prepare a long-running future stream to accept and serve cients.
67-
let incoming_tls_stream: Pin<Box<dyn Stream<Item = Result<TlsStream<TcpStream>, io::Error>>>> =
68-
tcp.incoming()
69-
.map_err(|e| error(format!("Incoming failed: {:?}", e)))
70-
.and_then(move |s| {
71-
tls_acceptor.accept(s).map_err(|e| {
72-
println!("[!] Voluntary server halt due to client-connection error...");
73-
// Errors could be handled here, instead of server aborting.
74-
// Ok(None)
75-
error(format!("TLS Error: {:?}", e))
76-
})
63+
let incoming_tls_stream = tcp
64+
.incoming()
65+
.map_err(|e| error(format!("Incoming failed: {:?}", e)))
66+
.and_then(move |s| {
67+
tls_acceptor.accept(s).map_err(|e| {
68+
println!("[!] Voluntary server halt due to client-connection error...");
69+
// Errors could be handled here, instead of server aborting.
70+
// Ok(None)
71+
error(format!("TLS Error: {:?}", e))
7772
})
78-
.boxed();
73+
})
74+
.boxed();
7975

8076
let service = make_service_fn(|_| async { Ok::<_, io::Error>(service_fn(echo)) });
8177
let server = Server::builder(HyperAcceptor {
@@ -89,11 +85,11 @@ async fn run_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
8985
Ok(())
9086
}
9187

92-
struct HyperAcceptor {
93-
acceptor: Pin<Box<dyn Stream<Item = Result<TlsStream<TcpStream>, io::Error>>>>,
88+
struct HyperAcceptor<'a> {
89+
acceptor: Pin<Box<dyn Stream<Item = Result<TlsStream<TcpStream>, io::Error>> + 'a>>,
9490
}
9591

96-
impl hyper::server::accept::Accept for HyperAcceptor {
92+
impl hyper::server::accept::Accept for HyperAcceptor<'_> {
9793
type Conn = TlsStream<TcpStream>;
9894
type Error = io::Error;
9995

src/connector.rs

+34-34
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
use futures_util::FutureExt;
2-
use hyper::client::connect::{self, Connect};
32
#[cfg(feature = "tokio-runtime")]
4-
use hyper::client::HttpConnector;
5-
use rustls::{ClientConfig, Session};
3+
use hyper::client::connect::HttpConnector;
4+
use hyper::{client::connect::Connection, service::Service, Uri};
5+
use rustls::ClientConfig;
66
use std::future::Future;
77
use std::pin::Pin;
88
use std::sync::Arc;
9+
use std::task::{Context, Poll};
910
use std::{fmt, io};
11+
use tokio::io::{AsyncRead, AsyncWrite};
1012
use tokio_rustls::TlsConnector;
1113
use webpki::DNSNameRef;
1214

1315
use crate::stream::MaybeHttpsStream;
1416

17+
type BoxError = Box<dyn std::error::Error + Send + Sync>;
18+
1519
/// A Connector for the `https` scheme.
1620
#[derive(Clone)]
1721
pub struct HttpsConnector<T> {
@@ -70,59 +74,55 @@ impl<T> From<(T, Arc<ClientConfig>)> for HttpsConnector<T> {
7074
}
7175
}
7276

73-
impl<T> Connect for HttpsConnector<T>
77+
impl<T> Service<Uri> for HttpsConnector<T>
7478
where
75-
T: Connect<Error = io::Error>,
76-
T::Transport: 'static,
77-
T::Future: 'static,
79+
T: Service<Uri>,
80+
T::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static,
81+
T::Future: Send + 'static,
82+
T::Error: Into<BoxError>,
7883
{
79-
type Transport = MaybeHttpsStream<T::Transport>;
80-
type Error = io::Error;
84+
type Response = MaybeHttpsStream<T::Response>;
85+
type Error = BoxError;
8186

8287
#[allow(clippy::type_complexity)]
83-
type Future = Pin<
84-
Box<
85-
dyn Future<
86-
Output = Result<
87-
(MaybeHttpsStream<T::Transport>, connect::Connected),
88-
io::Error,
89-
>,
90-
> + Send,
91-
>,
92-
>;
93-
94-
fn connect(&self, dst: connect::Destination) -> Self::Future {
95-
let is_https = dst.scheme() == "https";
88+
type Future =
89+
Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T::Response>, BoxError>> + Send>>;
90+
91+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92+
match self.http.poll_ready(cx) {
93+
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
94+
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
95+
Poll::Pending => Poll::Pending,
96+
}
97+
}
98+
99+
fn call(&mut self, dst: Uri) -> Self::Future {
100+
let is_https = dst.scheme_str() == Some("https");
96101

97102
if !is_https {
98-
let connecting_future = self.http.connect(dst);
103+
let connecting_future = self.http.call(dst);
99104

100105
let f = async move {
101-
let (tcp, conn) = connecting_future.await?;
106+
let tcp = connecting_future.await.map_err(Into::into)?;
102107

103-
Ok((MaybeHttpsStream::Http(tcp), conn))
108+
Ok(MaybeHttpsStream::Http(tcp))
104109
};
105110
f.boxed()
106111
} else {
107112
let cfg = self.tls_config.clone();
108-
let hostname = dst.host().to_string();
109-
let connecting_future = self.http.connect(dst);
113+
let hostname = dst.host().unwrap_or_default().to_string();
114+
let connecting_future = self.http.call(dst);
110115

111116
let f = async move {
112-
let (tcp, conn) = connecting_future.await?;
117+
let tcp = connecting_future.await.map_err(Into::into)?;
113118
let connector = TlsConnector::from(cfg);
114119
let dnsname = DNSNameRef::try_from_ascii_str(&hostname)
115120
.map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid dnsname"))?;
116121
let tls = connector
117122
.connect(dnsname, tcp)
118123
.await
119124
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
120-
let connected = if tls.get_ref().1.get_alpn_protocol() == Some(b"h2") {
121-
conn.negotiated_h2()
122-
} else {
123-
conn
124-
};
125-
Ok((MaybeHttpsStream::Https(tls), connected))
125+
Ok(MaybeHttpsStream::Https(tls))
126126
};
127127
f.boxed()
128128
}

src/stream.rs

+23-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
// Copied from hyperium/hyper-tls#62e3376/src/stream.rs
22
use std::fmt;
33
use std::io;
4+
use std::mem::MaybeUninit;
45
use std::pin::Pin;
56
use std::task::{Context, Poll};
67

7-
use tokio_io::{AsyncRead, AsyncWrite};
8+
use hyper::client::connect::{Connected, Connection};
9+
10+
use rustls::Session;
11+
use tokio::io::{AsyncRead, AsyncWrite};
812
use tokio_rustls::client::TlsStream;
913

1014
/// A stream that might be protected with TLS.
@@ -15,6 +19,22 @@ pub enum MaybeHttpsStream<T> {
1519
Https(TlsStream<T>),
1620
}
1721

22+
impl<T: AsyncRead + AsyncWrite + Connection + Unpin> Connection for MaybeHttpsStream<T> {
23+
fn connected(&self) -> Connected {
24+
match self {
25+
MaybeHttpsStream::Http(s) => s.connected(),
26+
MaybeHttpsStream::Https(s) => {
27+
let (tcp, tls) = s.get_ref();
28+
if tls.get_alpn_protocol() == Some(b"h2") {
29+
tcp.connected().negotiated_h2()
30+
} else {
31+
tcp.connected()
32+
}
33+
}
34+
}
35+
}
36+
}
37+
1838
impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> {
1939
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2040
match *self {
@@ -38,7 +58,7 @@ impl<T> From<TlsStream<T>> for MaybeHttpsStream<T> {
3858

3959
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeHttpsStream<T> {
4060
#[inline]
41-
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
61+
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
4262
match self {
4363
MaybeHttpsStream::Http(s) => s.prepare_uninitialized_buffer(buf),
4464
MaybeHttpsStream::Https(s) => s.prepare_uninitialized_buffer(buf),
@@ -86,4 +106,4 @@ impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for MaybeHttpsStream<T> {
86106
MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(cx),
87107
}
88108
}
89-
}
109+
}

0 commit comments

Comments
 (0)