Skip to content

Commit d151822

Browse files
committed
feat(awc): allow to set a specific sni host on the request
1 parent 002c1b5 commit d151822

File tree

12 files changed

+382
-102
lines changed

12 files changed

+382
-102
lines changed

awc/CHANGES.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
- Update `brotli` dependency to `7`.
66
- Prevent panics on connection pool drop when Tokio runtime is shutdown early.
77
- Minimum supported Rust version (MSRV) is now 1.75.
8+
- Allow to set a specific SNI hostname on the request for TLS connections.
89

910
## 3.5.1
1011

awc/src/builder.rs

+14-8
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@ use std::{fmt, net::IpAddr, rc::Rc, time::Duration};
33
use actix_http::{
44
error::HttpError,
55
header::{self, HeaderMap, HeaderName, TryIntoHeaderPair},
6-
Uri,
76
};
87
use actix_rt::net::{ActixStream, TcpStream};
98
use actix_service::{boxed, Service};
109
use base64::prelude::*;
1110

1211
use crate::{
1312
client::{
14-
ClientConfig, ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection,
13+
ClientConfig, ConnectInfo, Connector, ConnectorService, HostnameWithSni, TcpConnectError,
14+
TcpConnection,
1515
},
1616
connect::DefaultConnector,
1717
error::SendRequestError,
@@ -46,8 +46,8 @@ impl ClientBuilder {
4646
#[allow(clippy::new_ret_no_self)]
4747
pub fn new() -> ClientBuilder<
4848
impl Service<
49-
ConnectInfo<Uri>,
50-
Response = TcpConnection<Uri, TcpStream>,
49+
ConnectInfo<HostnameWithSni>,
50+
Response = TcpConnection<HostnameWithSni, TcpStream>,
5151
Error = TcpConnectError,
5252
> + Clone,
5353
(),
@@ -69,16 +69,22 @@ impl ClientBuilder {
6969

7070
impl<S, Io, M> ClientBuilder<S, M>
7171
where
72-
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError>
73-
+ Clone
72+
S: Service<
73+
ConnectInfo<HostnameWithSni>,
74+
Response = TcpConnection<HostnameWithSni, Io>,
75+
Error = TcpConnectError,
76+
> + Clone
7477
+ 'static,
7578
Io: ActixStream + fmt::Debug + 'static,
7679
{
7780
/// Use custom connector service.
7881
pub fn connector<S1, Io1>(self, connector: Connector<S1>) -> ClientBuilder<S1, M>
7982
where
80-
S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError>
81-
+ Clone
83+
S1: Service<
84+
ConnectInfo<HostnameWithSni>,
85+
Response = TcpConnection<HostnameWithSni, Io1>,
86+
Error = TcpConnectError,
87+
> + Clone
8288
+ 'static,
8389
Io1: ActixStream + fmt::Debug + 'static,
8490
{

awc/src/client/connector.rs

+91-35
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,51 @@ use actix_rt::{
1616
use actix_service::Service;
1717
use actix_tls::connect::{
1818
ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection,
19-
Connector as TcpConnector, Resolver,
19+
Connector as TcpConnector, Host, Resolver,
2020
};
2121
use futures_core::{future::LocalBoxFuture, ready};
22-
use http::Uri;
2322
use pin_project_lite::pin_project;
2423

2524
use super::{
2625
config::ConnectorConfig,
2726
connection::{Connection, ConnectionIo},
2827
error::ConnectError,
2928
pool::ConnectionPool,
30-
Connect,
29+
Connect, ServerName,
3130
};
3231

32+
pub enum HostnameWithSni {
33+
ForTcp(String, u16, Option<ServerName>),
34+
ForTls(String, u16, Option<ServerName>),
35+
}
36+
37+
impl Host for HostnameWithSni {
38+
fn hostname(&self) -> &str {
39+
match self {
40+
HostnameWithSni::ForTcp(hostname, _, _) => hostname,
41+
HostnameWithSni::ForTls(hostname, _, sni) => sni.as_deref().unwrap_or(hostname),
42+
}
43+
}
44+
45+
fn port(&self) -> Option<u16> {
46+
match self {
47+
HostnameWithSni::ForTcp(_, port, _) => Some(*port),
48+
HostnameWithSni::ForTls(_, port, _) => Some(*port),
49+
}
50+
}
51+
}
52+
53+
impl HostnameWithSni {
54+
pub fn to_tls(self) -> Self {
55+
match self {
56+
HostnameWithSni::ForTcp(hostname, port, sni) => {
57+
HostnameWithSni::ForTls(hostname, port, sni)
58+
}
59+
HostnameWithSni::ForTls(_, _, _) => self,
60+
}
61+
}
62+
}
63+
3364
enum OurTlsConnector {
3465
#[allow(dead_code)] // only dead when no TLS feature is enabled
3566
None,
@@ -95,8 +126,8 @@ impl Connector<()> {
95126
#[allow(clippy::new_ret_no_self, clippy::let_unit_value)]
96127
pub fn new() -> Connector<
97128
impl Service<
98-
ConnectInfo<Uri>,
99-
Response = TcpConnection<Uri, TcpStream>,
129+
ConnectInfo<HostnameWithSni>,
130+
Response = TcpConnection<HostnameWithSni, TcpStream>,
100131
Error = actix_tls::connect::ConnectError,
101132
> + Clone,
102133
> {
@@ -214,8 +245,11 @@ impl<S> Connector<S> {
214245
pub fn connector<S1, Io1>(self, connector: S1) -> Connector<S1>
215246
where
216247
Io1: ActixStream + fmt::Debug + 'static,
217-
S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError>
218-
+ Clone,
248+
S1: Service<
249+
ConnectInfo<HostnameWithSni>,
250+
Response = TcpConnection<HostnameWithSni, Io1>,
251+
Error = TcpConnectError,
252+
> + Clone,
219253
{
220254
Connector {
221255
connector,
@@ -235,8 +269,11 @@ where
235269
// This remap is to hide ActixStream's trait methods. They are not meant to be called
236270
// from user code.
237271
IO: ActixStream + fmt::Debug + 'static,
238-
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, IO>, Error = TcpConnectError>
239-
+ Clone
272+
S: Service<
273+
ConnectInfo<HostnameWithSni>,
274+
Response = TcpConnection<HostnameWithSni, IO>,
275+
Error = TcpConnectError,
276+
> + Clone
240277
+ 'static,
241278
{
242279
/// Sets TCP connection timeout.
@@ -454,7 +491,7 @@ where
454491
use actix_utils::future::{ready, Ready};
455492

456493
#[allow(non_local_definitions)]
457-
impl IntoConnectionIo for TcpConnection<Uri, Box<dyn ConnectionIo>> {
494+
impl IntoConnectionIo for TcpConnection<HostnameWithSni, Box<dyn ConnectionIo>> {
458495
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
459496
let io = self.into_parts().0;
460497
(io, Protocol::Http2)
@@ -505,7 +542,7 @@ where
505542
use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector};
506543

507544
#[allow(non_local_definitions)]
508-
impl<IO: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncSslStream<IO>> {
545+
impl<IO: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncSslStream<IO>> {
509546
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
510547
let sock = self.into_parts().0;
511548
let h2 = sock
@@ -543,7 +580,7 @@ where
543580
use actix_tls::connect::rustls_0_20::{reexports::AsyncTlsStream, TlsConnector};
544581

545582
#[allow(non_local_definitions)]
546-
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
583+
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
547584
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
548585
let sock = self.into_parts().0;
549586
let h2 = sock
@@ -577,7 +614,7 @@ where
577614
use actix_tls::connect::rustls_0_21::{reexports::AsyncTlsStream, TlsConnector};
578615

579616
#[allow(non_local_definitions)]
580-
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
617+
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
581618
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
582619
let sock = self.into_parts().0;
583620
let h2 = sock
@@ -614,7 +651,7 @@ where
614651
use actix_tls::connect::rustls_0_22::{reexports::AsyncTlsStream, TlsConnector};
615652

616653
#[allow(non_local_definitions)]
617-
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
654+
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
618655
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
619656
let sock = self.into_parts().0;
620657
let h2 = sock
@@ -648,7 +685,7 @@ where
648685
use actix_tls::connect::rustls_0_23::{reexports::AsyncTlsStream, TlsConnector};
649686

650687
#[allow(non_local_definitions)]
651-
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
688+
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
652689
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
653690
let sock = self.into_parts().0;
654691
let h2 = sock
@@ -688,15 +725,17 @@ where
688725
}
689726
}
690727

691-
/// tcp service for map `TcpConnection<Uri, Io>` type to `(Io, Protocol)`
728+
/// tcp service for map `TcpConnection<HostnameWithSni, Io>` type to `(Io, Protocol)`
692729
#[derive(Clone)]
693730
pub struct TcpConnectorService<S: Clone> {
694731
service: S,
695732
}
696733

697734
impl<S, Io> Service<Connect> for TcpConnectorService<S>
698735
where
699-
S: Service<Connect, Response = TcpConnection<Uri, Io>, Error = ConnectError> + Clone + 'static,
736+
S: Service<Connect, Response = TcpConnection<HostnameWithSni, Io>, Error = ConnectError>
737+
+ Clone
738+
+ 'static,
700739
{
701740
type Response = (Io, Protocol);
702741
type Error = ConnectError;
@@ -721,7 +760,7 @@ pin_project! {
721760

722761
impl<Fut, Io> Future for TcpConnectorFuture<Fut>
723762
where
724-
Fut: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
763+
Fut: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>>,
725764
{
726765
type Output = Result<(Io, Protocol), ConnectError>;
727766

@@ -767,9 +806,10 @@ struct TlsConnectorService<Tcp, Tls> {
767806
))]
768807
impl<Tcp, Tls, IO> Service<Connect> for TlsConnectorService<Tcp, Tls>
769808
where
770-
Tcp:
771-
Service<Connect, Response = TcpConnection<Uri, IO>, Error = ConnectError> + Clone + 'static,
772-
Tls: Service<TcpConnection<Uri, IO>, Error = std::io::Error> + Clone + 'static,
809+
Tcp: Service<Connect, Response = TcpConnection<HostnameWithSni, IO>, Error = ConnectError>
810+
+ Clone
811+
+ 'static,
812+
Tls: Service<TcpConnection<HostnameWithSni, IO>, Error = std::io::Error> + Clone + 'static,
773813
Tls::Response: IntoConnectionIo,
774814
IO: ConnectionIo,
775815
{
@@ -822,9 +862,14 @@ trait IntoConnectionIo {
822862

823863
impl<S, Io, Fut1, Fut2, Res> Future for TlsConnectorFuture<S, Fut1, Fut2>
824864
where
825-
S: Service<TcpConnection<Uri, Io>, Response = Res, Error = std::io::Error, Future = Fut2>,
865+
S: Service<
866+
TcpConnection<HostnameWithSni, Io>,
867+
Response = Res,
868+
Error = std::io::Error,
869+
Future = Fut2,
870+
>,
826871
S::Response: IntoConnectionIo,
827-
Fut1: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
872+
Fut1: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>>,
828873
Fut2: Future<Output = Result<S::Response, S::Error>>,
829874
Io: ConnectionIo,
830875
{
@@ -838,10 +883,11 @@ where
838883
timeout,
839884
} => {
840885
let res = ready!(fut.poll(cx))?;
886+
let (io, hostname_with_sni) = res.into_parts();
841887
let fut = tls_service
842888
.take()
843889
.expect("TlsConnectorFuture polled after complete")
844-
.call(res);
890+
.call(TcpConnection::new(hostname_with_sni.to_tls(), io));
845891
let timeout = sleep(*timeout);
846892
self.set(TlsConnectorFuture::TlsConnect { fut, timeout });
847893
self.poll(cx)
@@ -875,8 +921,11 @@ impl<S: Clone> TcpConnectorInnerService<S> {
875921

876922
impl<S, Io> Service<Connect> for TcpConnectorInnerService<S>
877923
where
878-
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError>
879-
+ Clone
924+
S: Service<
925+
ConnectInfo<HostnameWithSni>,
926+
Response = TcpConnection<HostnameWithSni, Io>,
927+
Error = TcpConnectError,
928+
> + Clone
880929
+ 'static,
881930
{
882931
type Response = S::Response;
@@ -886,7 +935,13 @@ where
886935
actix_service::forward_ready!(service);
887936

888937
fn call(&self, req: Connect) -> Self::Future {
889-
let mut req = ConnectInfo::new(req.uri).set_addr(req.addr);
938+
let mut req = ConnectInfo::new(HostnameWithSni::ForTcp(
939+
req.hostname,
940+
req.port,
941+
req.sni_host,
942+
))
943+
.set_addr(req.addr)
944+
.set_port(req.port);
890945

891946
if let Some(local_addr) = self.local_address {
892947
req = req.set_local_addr(local_addr);
@@ -911,9 +966,9 @@ pin_project! {
911966

912967
impl<Fut, Io> Future for TcpConnectorInnerFuture<Fut>
913968
where
914-
Fut: Future<Output = Result<TcpConnection<Uri, Io>, TcpConnectError>>,
969+
Fut: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, TcpConnectError>>,
915970
{
916-
type Output = Result<TcpConnection<Uri, Io>, ConnectError>;
971+
type Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>;
917972

918973
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
919974
let this = self.project();
@@ -973,16 +1028,17 @@ where
9731028
}
9741029

9751030
fn call(&self, req: Connect) -> Self::Future {
976-
match req.uri.scheme_str() {
977-
Some("https") | Some("wss") => match self.tls_pool {
1031+
if req.tls {
1032+
match &self.tls_pool {
9781033
None => ConnectorServiceFuture::SslIsNotSupported,
979-
Some(ref pool) => ConnectorServiceFuture::Tls {
1034+
Some(pool) => ConnectorServiceFuture::Tls {
9801035
fut: pool.call(req),
9811036
},
982-
},
983-
_ => ConnectorServiceFuture::Tcp {
1037+
}
1038+
} else {
1039+
ConnectorServiceFuture::Tcp {
9841040
fut: self.tcp_pool.call(req),
985-
},
1041+
}
9861042
}
9871043
}
9881044
}

0 commit comments

Comments
 (0)