Skip to content

Commit 2eb7c06

Browse files
committed
De-duplicate mysql & postgres TLS code
1 parent f985eec commit 2eb7c06

File tree

3 files changed

+54
-60
lines changed

3 files changed

+54
-60
lines changed

sqlx-core/src/mysql/connection/tls.rs

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
use sqlx_rt::{
2-
fs,
3-
native_tls::{Certificate, TlsConnector},
4-
};
5-
61
use crate::error::Error;
72
use crate::mysql::connection::MySqlStream;
83
use crate::mysql::protocol::connect::SslRequest;
@@ -46,34 +41,20 @@ async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Res
4641

4742
stream.flush().await?;
4843

49-
// FIXME: de-duplicate with postgres/connection/tls.rs
50-
5144
let accept_invalid_certs = !matches!(
5245
options.ssl_mode,
5346
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
5447
);
55-
56-
let mut builder = TlsConnector::builder();
57-
builder
58-
.danger_accept_invalid_certs(accept_invalid_certs)
59-
.danger_accept_invalid_hostnames(!matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity));
60-
61-
if !accept_invalid_certs {
62-
if let Some(ca) = &options.ssl_ca {
63-
let data = fs::read(ca).await?;
64-
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
65-
66-
builder.add_root_certificate(cert);
67-
}
68-
}
69-
70-
#[cfg(not(feature = "_rt-async-std"))]
71-
let connector = builder.build().map_err(Error::tls)?;
72-
73-
#[cfg(feature = "_rt-async-std")]
74-
let connector = builder;
75-
76-
stream.upgrade(&options.host, connector.into()).await?;
48+
let accept_invalid_host_names = !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity);
49+
50+
stream
51+
.upgrade(
52+
&options.host,
53+
accept_invalid_certs,
54+
accept_invalid_host_names,
55+
options.ssl_ca.as_deref(),
56+
)
57+
.await?;
7758

7859
Ok(true)
7960
}

sqlx-core/src/net/tls.rs

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@
22

33
use std::io;
44
use std::ops::{Deref, DerefMut};
5+
use std::path::Path;
56
use std::pin::Pin;
67
use std::task::{Context, Poll};
78

8-
use sqlx_rt::{AsyncRead, AsyncWrite, TlsConnector, TlsStream};
9+
use sqlx_rt::{
10+
fs,
11+
native_tls::{Certificate, TlsConnector},
12+
AsyncRead, AsyncWrite, TlsStream,
13+
};
914

1015
use crate::error::Error;
1116
use std::mem::replace;
@@ -28,7 +33,33 @@ where
2833
matches!(self, Self::Tls(_))
2934
}
3035

31-
pub async fn upgrade(&mut self, host: &str, connector: TlsConnector) -> Result<(), Error> {
36+
pub async fn upgrade(
37+
&mut self,
38+
host: &str,
39+
accept_invalid_certs: bool,
40+
accept_invalid_hostnames: bool,
41+
root_cert_path: Option<&Path>,
42+
) -> Result<(), Error> {
43+
let mut builder = TlsConnector::builder();
44+
builder
45+
.danger_accept_invalid_certs(accept_invalid_certs)
46+
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
47+
48+
if !accept_invalid_certs {
49+
if let Some(ca) = root_cert_path {
50+
let data = fs::read(ca).await?;
51+
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
52+
53+
builder.add_root_certificate(cert);
54+
}
55+
}
56+
57+
#[cfg(not(feature = "_rt-async-std"))]
58+
let connector = builder.build().map_err(Error::tls)?;
59+
60+
#[cfg(feature = "_rt-async-std")]
61+
let connector = builder;
62+
3263
let stream = match replace(self, MaybeTlsStream::Upgrading) {
3364
MaybeTlsStream::Raw(stream) => stream,
3465

@@ -45,7 +76,7 @@ where
4576
};
4677

4778
*self = MaybeTlsStream::Tls(
48-
connector
79+
sqlx_rt::TlsConnector::from(connector)
4980
.connect(host, stream)
5081
.await
5182
.map_err(|err| Error::Tls(err.into()))?,

sqlx-core/src/postgres/connection/tls.rs

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
11
use bytes::Bytes;
2-
use sqlx_rt::{
3-
fs,
4-
native_tls::{Certificate, TlsConnector},
5-
};
62

73
use crate::error::Error;
84
use crate::postgres::connection::stream::PgStream;
@@ -63,34 +59,20 @@ async fn upgrade(stream: &mut PgStream, options: &PgConnectOptions) -> Result<bo
6359
}
6460
}
6561

66-
// FIXME: de-duplicate with mysql/connection/tls.rs
67-
6862
let accept_invalid_certs = !matches!(
6963
options.ssl_mode,
7064
PgSslMode::VerifyCa | PgSslMode::VerifyFull
7165
);
72-
73-
let mut builder = TlsConnector::builder();
74-
builder
75-
.danger_accept_invalid_certs(accept_invalid_certs)
76-
.danger_accept_invalid_hostnames(!matches!(options.ssl_mode, PgSslMode::VerifyFull));
77-
78-
if !accept_invalid_certs {
79-
if let Some(ca) = &options.ssl_root_cert {
80-
let data = fs::read(ca).await?;
81-
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
82-
83-
builder.add_root_certificate(cert);
84-
}
85-
}
86-
87-
#[cfg(not(feature = "_rt-async-std"))]
88-
let connector = builder.build().map_err(Error::tls)?;
89-
90-
#[cfg(feature = "_rt-async-std")]
91-
let connector = builder;
92-
93-
stream.upgrade(&options.host, connector.into()).await?;
66+
let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull);
67+
68+
stream
69+
.upgrade(
70+
&options.host,
71+
accept_invalid_certs,
72+
accept_invalid_hostnames,
73+
options.ssl_root_cert.as_deref(),
74+
)
75+
.await?;
9476

9577
Ok(true)
9678
}

0 commit comments

Comments
 (0)