Skip to content

Commit ee24de7

Browse files
committed
Add rustls support
1 parent 70a22e2 commit ee24de7

File tree

12 files changed

+345
-110
lines changed

12 files changed

+345
-110
lines changed

Cargo.lock

Lines changed: 191 additions & 70 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ runtime-actix-native-tls = [ "sqlx-core/runtime-actix-native-tls", "sqlx-macros/
6363
runtime-async-std-native-tls = [ "sqlx-core/runtime-async-std-native-tls", "sqlx-macros/runtime-async-std-native-tls", "_rt-async-std" ]
6464
runtime-tokio-native-tls = [ "sqlx-core/runtime-tokio-native-tls", "sqlx-macros/runtime-tokio-native-tls", "_rt-tokio" ]
6565

66+
runtime-actix-rustls = [ "sqlx-core/runtime-actix-rustls", "sqlx-macros/runtime-actix-rustls", "_rt-actix" ]
67+
runtime-async-std-rustls = [ "sqlx-core/runtime-async-std-rustls", "sqlx-macros/runtime-async-std-rustls", "_rt-async-std" ]
68+
runtime-tokio-rustls = [ "sqlx-core/runtime-tokio-rustls", "sqlx-macros/runtime-tokio-rustls", "_rt-tokio" ]
69+
6670
# for conditional compilation
6771
_rt-actix = []
6872
_rt-async-std = []

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ SQLx is an async, pure Rust<sub>†</sub> SQL crate featuring compile-time check
6666

6767
* **Pure Rust**. The Postgres and MySQL/MariaDB drivers are written in pure Rust using **zero** unsafe<sub>††</sub> code.
6868

69-
* **Runtime Agnostic**. Works on different runtimes ([async-std](https://crates.io/crates/async-std) / [tokio](https://crates.io/crates/tokio) / [actix](https://crates.io/crates/actix-rt)).
69+
* **Runtime Agnostic**. Works on different runtimes ([async-std](https://crates.io/crates/async-std) / [tokio](https://crates.io/crates/tokio) / [actix](https://crates.io/crates/actix-rt)) and TLS backends ([native-tls](https://crates.io/crates/native-tls), [rustls](https://crates.io/crates/rustls)).
7070

7171
<sub><sup>† The SQLite driver uses the libsqlite3 C library as SQLite is an embedded database (the only way
7272
we could be pure Rust for SQLite is by porting _all_ of SQLite to Rust).</sup></sub>
@@ -109,12 +109,14 @@ SQLx is compatible with the [`async-std`], [`tokio`] and [`actix`] runtimes.
109109
[`tokio`]: https://github.com/tokio-rs/tokio
110110
[`actix`]: https://github.com/actix/actix-net
111111

112-
By default, you get `async-std`. If you want a different runtime or TLS backend, just disable the default features and activate the corresponding feature, for example for tokio:
112+
You can also select between [`native-tls`] and [`rustls`] for the TLS backend.
113+
114+
By default, you get `async-std` + `native-tls`. If you want a different runtime or TLS backend, just disable the default features and activate the corresponding feature, for example for tokio + rustls:
113115

114116
```toml
115117
# Cargo.toml
116118
[dependencies]
117-
sqlx = { version = "0.4.0-beta.1", default-features = false, features = [ "runtime-tokio-native-tls", "macros" ] }
119+
sqlx = { version = "0.4.0-beta.1", default-features = false, features = [ "runtime-tokio-rustls", "macros" ] }
118120
```
119121

120122
<sub><sup>The runtime and TLS backend not being separate feature sets to select is a workaround for a [Cargo issue](https://github.com/rust-lang/cargo/issues/3494).</sup></sub>
@@ -123,10 +125,16 @@ sqlx = { version = "0.4.0-beta.1", default-features = false, features = [ "runti
123125

124126
* `runtime-async-std-native-tls` (on by default): Use the `async-std` runtime and `native-tls` TLS backend.
125127

128+
* `runtime-async-std-rustls`: Use the `async-std` runtime and `rustls` TLS backend.
129+
126130
* `runtime-tokio-native-tls`: Use the `tokio` runtime and `native-tls` TLS backend.
127131

132+
* `runtime-tokio-rustls`: Use the `tokio` runtime and `rustls` TLS backend.
133+
128134
* `runtime-actix-native-tls`: Use the `actix` runtime and `native-tls` TLS backend.
129135

136+
* `runtime-actix-rustls`: Use the `actix` runtime and `rustls` TLS backend.
137+
130138
* `postgres`: Add support for the Postgres database server.
131139

132140
* `mysql`: Add support for the MySQL (and MariaDB) database server.

sqlx-bench/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ runtime-actix-native-tls = [ "sqlx/runtime-actix-native-tls", "sqlx-rt/runtime-a
1010
runtime-async-std-native-tls = [ "sqlx/runtime-async-std-native-tls", "sqlx-rt/runtime-async-std-native-tls" ]
1111
runtime-tokio-native-tls = [ "sqlx/runtime-tokio-native-tls", "sqlx-rt/runtime-tokio-native-tls" ]
1212

13+
runtime-actix-rustls = [ "sqlx/runtime-actix-rustls", "sqlx-rt/runtime-actix-rustls" ]
14+
runtime-async-std-rustls = [ "sqlx/runtime-async-std-rustls", "sqlx-rt/runtime-async-std-rustls" ]
15+
runtime-tokio-rustls = [ "sqlx/runtime-tokio-rustls", "sqlx-rt/runtime-tokio-rustls" ]
16+
1317
postgres = ["sqlx/postgres"]
1418

1519
[dependencies]

sqlx-bench/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ You must choose a runtime to execute the benchmarks on; the feature flags are th
2424

2525
```bash
2626
cargo bench --features runtime-tokio-native-tls
27-
cargo bench --features runtime-async-std-native-tls
27+
cargo bench --features runtime-async-std-rustls
2828
```
2929

3030
When complete, the benchmark results will be in `target/criterion/`.

sqlx-core/Cargo.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,16 @@ runtime-actix-native-tls = [ "sqlx-rt/runtime-actix-native-tls", "_tls-native-tl
3838
runtime-async-std-native-tls = [ "sqlx-rt/runtime-async-std-native-tls", "_tls-native-tls", "_rt-async-std" ]
3939
runtime-tokio-native-tls = [ "sqlx-rt/runtime-tokio-native-tls", "_tls-native-tls", "_rt-tokio" ]
4040

41+
runtime-actix-rustls = [ "sqlx-rt/runtime-actix-rustls", "_tls-rustls", "_rt-actix" ]
42+
runtime-async-std-rustls = [ "sqlx-rt/runtime-async-std-rustls", "_tls-rustls", "_rt-async-std" ]
43+
runtime-tokio-rustls = [ "sqlx-rt/runtime-tokio-rustls", "_tls-rustls", "_rt-tokio" ]
44+
4145
# for conditional compilation
4246
_rt-actix = []
4347
_rt-async-std = []
4448
_rt-tokio = []
4549
_tls-native-tls = []
50+
_tls-rustls = [ "rustls", "webpki" ]
4651

4752
# support offline/decoupled building (enables serialization of `Describe`)
4853
offline = [ "serde", "either/serde" ]
@@ -86,6 +91,7 @@ parking_lot = "0.11.0"
8691
rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] }
8792
regex = { version = "1.3.9", optional = true }
8893
rsa = { version = "0.3.0", optional = true }
94+
rustls = { version = "0.18.1", optional = true }
8995
serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true }
9096
serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true }
9197
sha-1 = { version = "0.9.0", default-features = false, optional = true }
@@ -96,6 +102,7 @@ time = { version = "0.2.16", optional = true }
96102
smallvec = "1.4.0"
97103
url = { version = "2.1.1", default-features = false }
98104
uuid = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] }
105+
webpki = { version = "0.21.3", optional = true }
99106
whoami = "0.9.0"
100107
stringprep = "0.1.2"
101108
lru-cache = "0.1.2"

sqlx-core/src/error.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,14 @@ impl From<sqlx_rt::native_tls::Error> for Error {
242242
}
243243
}
244244

245+
#[cfg(feature = "_tls-rustls")]
246+
impl From<webpki::InvalidDNSNameError> for Error {
247+
#[inline]
248+
fn from(error: webpki::InvalidDNSNameError) -> Self {
249+
Error::Tls(Box::new(error))
250+
}
251+
}
252+
245253
// Format an error message as a `Protocol` error
246254
macro_rules! err_protocol {
247255
($expr:expr) => {

sqlx-core/src/net/tls.rs

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,7 @@ use std::path::Path;
66
use std::pin::Pin;
77
use std::task::{Context, Poll};
88

9-
use sqlx_rt::{
10-
fs,
11-
native_tls::{Certificate, TlsConnector},
12-
AsyncRead, AsyncWrite, TlsStream,
13-
};
9+
use sqlx_rt::{fs, AsyncRead, AsyncWrite, TlsStream};
1410

1511
use crate::error::Error;
1612
use std::mem::replace;
@@ -40,25 +36,12 @@ where
4036
accept_invalid_hostnames: bool,
4137
root_cert_path: Option<&Path>,
4238
) -> Result<(), Error> {
43-
let mut builder = TlsConnector::builder();
44-
builder
45-
.danger_accept_invalid_certs(accept_invalid_certs)
46-
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
47-
48-
if !accept_invalid_certs {
49-
if let Some(ca) = root_cert_path {
50-
let data = fs::read(ca).await?;
51-
let cert = Certificate::from_pem(&data)?;
52-
53-
builder.add_root_certificate(cert);
54-
}
55-
}
56-
57-
#[cfg(not(feature = "_rt-async-std"))]
58-
let connector = sqlx_rt::TlsConnector::from(builder.build()?);
59-
60-
#[cfg(feature = "_rt-async-std")]
61-
let connector = sqlx_rt::TlsConnector::from(builder);
39+
let connector = configure_tls_connector(
40+
accept_invalid_certs,
41+
accept_invalid_hostnames,
42+
root_cert_path,
43+
)
44+
.await?;
6245

6346
let stream = match replace(self, MaybeTlsStream::Upgrading) {
6447
MaybeTlsStream::Raw(stream) => stream,
@@ -75,12 +58,71 @@ where
7558
}
7659
};
7760

61+
#[cfg(feature = "_tls-rustls")]
62+
let host = webpki::DNSNameRef::try_from_ascii_str(host)?;
63+
7864
*self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);
7965

8066
Ok(())
8167
}
8268
}
8369

70+
#[cfg(feature = "_tls-native-tls")]
71+
async fn configure_tls_connector(
72+
accept_invalid_certs: bool,
73+
accept_invalid_hostnames: bool,
74+
root_cert_path: Option<&Path>,
75+
) -> Result<sqlx_rt::TlsConnector, Error> {
76+
use sqlx_rt::native_tls::{Certificate, TlsConnector};
77+
78+
let mut builder = TlsConnector::builder();
79+
builder
80+
.danger_accept_invalid_certs(accept_invalid_certs)
81+
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
82+
83+
if !accept_invalid_certs {
84+
if let Some(ca) = root_cert_path {
85+
let data = fs::read(ca).await?;
86+
let cert = Certificate::from_pem(&data)?;
87+
88+
builder.add_root_certificate(cert);
89+
}
90+
}
91+
92+
#[cfg(not(feature = "_rt-async-std"))]
93+
let connector = builder.build()?.into();
94+
95+
#[cfg(feature = "_rt-async-std")]
96+
let connector = builder.into();
97+
98+
Ok(connector)
99+
}
100+
101+
#[cfg(feature = "_tls-rustls")]
102+
async fn configure_tls_connector(
103+
_accept_invalid_certs: bool,
104+
_accept_invalid_hostnames: bool,
105+
root_cert_path: Option<&Path>,
106+
) -> Result<sqlx_rt::TlsConnector, Error> {
107+
// FIXME: Support accept_invalid_certs / accept_invalid_hostnames
108+
109+
use rustls::ClientConfig;
110+
use std::io::Cursor;
111+
use std::sync::Arc;
112+
113+
let mut config = ClientConfig::new();
114+
115+
if let Some(ca) = root_cert_path {
116+
let data = fs::read(ca).await?;
117+
let mut cursor = Cursor::new(data);
118+
config.root_store.add_pem_file(&mut cursor).map_err(|_| {
119+
Error::Tls(format!("Invalid certificate file: {}", ca.display()).into())
120+
})?;
121+
}
122+
123+
Ok(Arc::new(config).into())
124+
}
125+
84126
impl<S> AsyncRead for MaybeTlsStream<S>
85127
where
86128
S: Unpin + AsyncWrite + AsyncRead,
@@ -192,12 +234,15 @@ where
192234
match self {
193235
MaybeTlsStream::Raw(s) => s,
194236

195-
#[cfg(not(feature = "_rt-async-std"))]
196-
MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
237+
#[cfg(feature = "_tls-rustls")]
238+
MaybeTlsStream::Tls(s) => s.get_ref().0,
197239

198-
#[cfg(feature = "_rt-async-std")]
240+
#[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
199241
MaybeTlsStream::Tls(s) => s.get_ref(),
200242

243+
#[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
244+
MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
245+
201246
MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)),
202247
}
203248
}
@@ -211,12 +256,15 @@ where
211256
match self {
212257
MaybeTlsStream::Raw(s) => s,
213258

214-
#[cfg(not(feature = "_rt-async-std"))]
215-
MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
259+
#[cfg(feature = "_tls-rustls")]
260+
MaybeTlsStream::Tls(s) => s.get_mut().0,
216261

217-
#[cfg(feature = "_rt-async-std")]
262+
#[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
218263
MaybeTlsStream::Tls(s) => s.get_mut(),
219264

265+
#[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
266+
MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
267+
220268
MaybeTlsStream::Upgrading => panic!(io::Error::from(io::ErrorKind::ConnectionAborted)),
221269
}
222270
}

sqlx-macros/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ runtime-actix-native-tls = [ "sqlx-core/runtime-actix-native-tls", "sqlx-rt/runt
2424
runtime-async-std-native-tls = [ "sqlx-core/runtime-async-std-native-tls", "sqlx-rt/runtime-async-std-native-tls", "_rt-async-std" ]
2525
runtime-tokio-native-tls = [ "sqlx-core/runtime-tokio-native-tls", "sqlx-rt/runtime-tokio-native-tls", "_rt-tokio" ]
2626

27+
runtime-actix-rustls = [ "sqlx-core/runtime-actix-rustls", "sqlx-rt/runtime-actix-rustls", "_rt-actix" ]
28+
runtime-async-std-rustls = [ "sqlx-core/runtime-async-std-rustls", "sqlx-rt/runtime-async-std-rustls", "_rt-async-std" ]
29+
runtime-tokio-rustls = [ "sqlx-core/runtime-tokio-rustls", "sqlx-rt/runtime-tokio-rustls", "_rt-tokio" ]
30+
2731
# for conditional compilation
2832
_rt-actix = []
2933
_rt-async-std = []

sqlx-rt/Cargo.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,25 @@ runtime-actix-native-tls = [ "_rt-actix", "_tls-native-tls", "tokio-native-tls"
1515
runtime-async-std-native-tls = [ "_rt-async-std", "_tls-native-tls", "async-native-tls" ]
1616
runtime-tokio-native-tls = [ "_rt-tokio", "_tls-native-tls", "tokio-native-tls" ]
1717

18+
runtime-actix-rustls = [ "_rt-actix", "_tls-rustls", "tokio-rustls" ]
19+
runtime-async-std-rustls = [ "_rt-async-std", "_tls-rustls", "async-rustls" ]
20+
runtime-tokio-rustls = [ "_rt-tokio", "_tls-rustls", "tokio-rustls" ]
21+
1822
# Not used directly and not re-exported from sqlx
1923
_rt-actix = [ "actix-rt", "actix-threadpool", "tokio", "once_cell" ]
2024
_rt-async-std = [ "async-std" ]
2125
_rt-tokio = [ "tokio", "once_cell" ]
2226
_tls-native-tls = [ "native-tls" ]
27+
_tls-rustls = [ ]
2328

2429
[dependencies]
2530
async-native-tls = { version = "0.3.3", optional = true }
31+
async-rustls = { version = "0.1.1", optional = true }
2632
actix-rt = { version = "1.1.1", optional = true }
2733
actix-threadpool = { version = "0.3.2", optional = true }
28-
async-std = { version = "1.6.0", features = [ "unstable" ], optional = true }
34+
async-std = { version = "1.6.5", features = [ "unstable" ], optional = true }
2935
tokio = { version = "0.2.21", optional = true, features = [ "blocking", "stream", "fs", "tcp", "uds", "macros", "rt-core", "rt-threaded", "time", "dns", "io-util" ] }
3036
tokio-native-tls = { version = "0.1.0", optional = true }
37+
tokio-rustls = { version = "0.14.0", optional = true }
3138
native-tls = { version = "0.2.4", optional = true }
3239
once_cell = { version = "1.4", features = ["std"], optional = true }

sqlx-rt/src/lib.rs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,26 @@
22
feature = "runtime-actix-native-tls",
33
feature = "runtime-async-std-native-tls",
44
feature = "runtime-tokio-native-tls",
5+
feature = "runtime-actix-rustls",
6+
feature = "runtime-async-std-rustls",
7+
feature = "runtime-tokio-rustls",
58
)))]
69
compile_error!(
710
"one of the features ['runtime-actix-native-tls', 'runtime-async-std-native-tls', \
8-
'runtime-tokio-native-tls'] must be enabled"
11+
'runtime-tokio-native-tls', 'runtime-actix-rustls', 'runtime-async-std-rustls', \
12+
'runtime-tokio-rustls'] must be enabled"
913
);
1014

1115
#[cfg(any(
1216
all(feature = "_rt-actix", feature = "_rt-async-std"),
1317
all(feature = "_rt-actix", feature = "_rt-tokio"),
1418
all(feature = "_rt-async-std", feature = "_rt-tokio"),
19+
all(feature = "_tls-native-tls", feature = "_tls-rustls"),
1520
))]
1621
compile_error!(
1722
"only one of ['runtime-actix-native-tls', 'runtime-async-std-native-tls', \
18-
'runtime-tokio-native-tls'] can be enabled"
23+
'runtime-tokio-native-tls', 'runtime-actix-rustls', 'runtime-async-std-rustls', \
24+
'runtime-tokio-rustls'] can be enabled"
1925
);
2026

2127
#[cfg(all(feature = "_tls-native-tls"))]
@@ -78,10 +84,17 @@ mod tokio_runtime {
7884
#[cfg(all(
7985
feature = "_tls-native-tls",
8086
any(feature = "_rt-tokio", feature = "_rt-actix"),
81-
not(feature = "_rt-async-std"),
87+
not(any(feature = "_tls-rustls", feature = "_rt-async-std")),
8288
))]
8389
pub use tokio_native_tls::{TlsConnector, TlsStream};
8490

91+
#[cfg(all(
92+
feature = "_tls-rustls",
93+
any(feature = "_rt-tokio", feature = "_rt-actix"),
94+
not(any(feature = "_tls-native-tls", feature = "_rt-async-std")),
95+
))]
96+
pub use tokio_rustls::{client::TlsStream, TlsConnector};
97+
8598
//
8699
// tokio
87100
//
@@ -170,3 +183,14 @@ where
170183

171184
#[cfg(all(feature = "async-native-tls", not(feature = "tokio-native-tls")))]
172185
pub use async_native_tls::{TlsConnector, TlsStream};
186+
187+
#[cfg(all(
188+
feature = "_tls-rustls",
189+
feature = "_rt-async-std",
190+
not(any(
191+
feature = "_tls-native-tls",
192+
feature = "_rt-tokio",
193+
feature = "_rt-actix"
194+
)),
195+
))]
196+
pub use async_rustls::{client::TlsStream, TlsConnector};

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
))]
88
compile_error!(
99
"the features 'runtime-actix', 'runtime-async-std' and 'runtime-tokio' have been removed in
10-
favor of new features 'runtime-{rt}-{tls}' where rt is one of 'actix', 'async-std' and
11-
'tokio'."
10+
favor of new features 'runtime-{rt}-{tls}' where rt is one of 'actix', 'async-std' and 'tokio'
11+
and 'tls' is one of 'native-tls' and 'rustls'."
1212
);
1313

1414
pub use sqlx_core::acquire::Acquire;

0 commit comments

Comments
 (0)