Skip to content

Commit 1fd8a7c

Browse files
committed
Support RFC 5077 TLS session ticket reuse
1 parent 97b77f4 commit 1fd8a7c

File tree

7 files changed

+369
-20
lines changed

7 files changed

+369
-20
lines changed

Cargo.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@ readme = "README.md"
1111
vendored = ["openssl/vendored"]
1212

1313
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
14-
security-framework = "0.4.1"
15-
security-framework-sys = "0.4.1"
14+
security-framework = { version = "0.4.4", features = ["session-tickets"] }
15+
security-framework-sys = "0.4.3"
1616
lazy_static = "1.0"
1717
libc = "0.2"
1818
tempfile = "3.0"
1919

2020
[target.'cfg(target_os = "windows")'.dependencies]
21-
schannel = "0.1.16"
21+
schannel = "0.1.18"
2222

2323
[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies]
24+
linked_hash_set = "0.1"
2425
log = "0.4.5"
26+
once_cell = "1.0"
2527
openssl = "0.10.29"
2628
openssl-sys = "0.9.55"
2729
openssl-probe = "0.1"

appveyor.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
image: Visual Studio 2017
12
environment:
23
RUST_VERSION: 1.37.0
34
TARGET: x86_64-pc-windows-msvc

build.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ fn main() {
77
if version >= 0x1_01_00_00_0 {
88
println!("cargo:rustc-cfg=have_min_max_version");
99
}
10+
if version >= 0x1_01_01_00_0 {
11+
println!("cargo:rustc-cfg=ossl111");
12+
}
1013
}
1114

1215
if let Ok(version) = env::var("DEP_OPENSSL_LIBRESSL_VERSION_NUMBER") {

src/imp/openssl.rs

Lines changed: 204 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
1+
extern crate linked_hash_set;
2+
extern crate once_cell;
13
extern crate openssl;
24
extern crate openssl_probe;
35

6+
use self::linked_hash_set::LinkedHashSet;
7+
use self::once_cell::sync::OnceCell;
48
use self::openssl::error::ErrorStack;
9+
use self::openssl::ex_data::Index;
510
use self::openssl::hash::MessageDigest;
611
use self::openssl::nid::Nid;
712
use self::openssl::pkcs12::Pkcs12;
813
use self::openssl::pkey::PKey;
914
use self::openssl::ssl::{
10-
self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
11-
SslVerifyMode,
15+
self, MidHandshakeSslStream, Ssl, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
16+
SslSession, SslSessionCacheMode, SslSessionRef, SslVerifyMode,
1217
};
1318
use self::openssl::x509::{X509, store::X509StoreBuilder, X509VerifyResult};
19+
use std::borrow::Borrow;
20+
use std::collections::hash_map::{Entry, HashMap};
1421
use std::error;
1522
use std::fmt;
23+
use std::hash::{Hash, Hasher};
1624
use std::io;
17-
use std::sync::Once;
25+
use std::sync::{Arc, Mutex, Once};
1826

1927
use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
2028
use self::openssl::pkey::Private;
@@ -248,6 +256,8 @@ pub struct TlsConnector {
248256
use_sni: bool,
249257
accept_invalid_hostnames: bool,
250258
accept_invalid_certs: bool,
259+
session_tickets_enabled: bool,
260+
session_cache: Arc<Mutex<SessionCache>>,
251261
}
252262

253263
impl TlsConnector {
@@ -277,11 +287,37 @@ impl TlsConnector {
277287
#[cfg(target_os = "android")]
278288
load_android_root_certs(&mut connector)?;
279289

290+
let session_cache = Arc::new(Mutex::new(SessionCache::new()));
291+
if builder.session_tickets_enabled {
292+
connector.set_session_cache_mode(SslSessionCacheMode::CLIENT);
293+
294+
connector.set_new_session_callback({
295+
let session_cache = session_cache.clone();
296+
move |ssl, session| {
297+
if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
298+
if let Ok(mut session_cache) = session_cache.lock() {
299+
session_cache.insert(key.clone(), session);
300+
}
301+
}
302+
}
303+
});
304+
connector.set_remove_session_callback({
305+
let session_cache = session_cache.clone();
306+
move |_, session| {
307+
if let Ok(mut session_cache) = session_cache.lock() {
308+
session_cache.remove(session);
309+
}
310+
}
311+
});
312+
}
313+
280314
Ok(TlsConnector {
281315
connector: connector.build(),
282316
use_sni: builder.use_sni,
283317
accept_invalid_hostnames: builder.accept_invalid_hostnames,
284318
accept_invalid_certs: builder.accept_invalid_certs,
319+
session_tickets_enabled: builder.session_tickets_enabled,
320+
session_cache,
285321
})
286322
}
287323

@@ -297,6 +333,23 @@ impl TlsConnector {
297333
if self.accept_invalid_certs {
298334
ssl.set_verify(SslVerifyMode::NONE);
299335
}
336+
if self.session_tickets_enabled {
337+
let key = SessionKey {
338+
host: domain.to_string(),
339+
};
340+
341+
if let Ok(mut session_cache) = self.session_cache.lock() {
342+
if let Some(session) = session_cache.get(&key) {
343+
// Note: the `unsafe`-ty here is because the `session` is required to come from the
344+
// same SSL_CTX that the ssl object (`ssl`) is from, since it maintains internal
345+
// pointers and refcounts. Here, we only have one SSL_CTX, so this is safe.
346+
unsafe { ssl.set_session(&session)? };
347+
}
348+
}
349+
350+
let idx = key_index()?;
351+
ssl.set_ex_data(idx, key);
352+
}
300353

301354
let s = ssl.connect(domain, stream)?;
302355
Ok(TlsStream(s))
@@ -412,3 +465,151 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
412465
self.0.flush()
413466
}
414467
}
468+
469+
fn key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack> {
470+
static IDX: OnceCell<Index<Ssl, SessionKey>> = OnceCell::new();
471+
IDX.get_or_try_init(|| Ssl::new_ex_index()).map(|v| *v)
472+
}
473+
474+
#[derive(Hash, PartialEq, Eq, Clone)]
475+
pub struct SessionKey {
476+
pub host: String,
477+
}
478+
479+
#[derive(Clone)]
480+
struct HashSession(SslSession);
481+
482+
impl PartialEq for HashSession {
483+
fn eq(&self, other: &HashSession) -> bool {
484+
self.0.id() == other.0.id()
485+
}
486+
}
487+
488+
impl Eq for HashSession {}
489+
490+
impl Hash for HashSession {
491+
fn hash<H>(&self, state: &mut H)
492+
where
493+
H: Hasher,
494+
{
495+
self.0.id().hash(state);
496+
}
497+
}
498+
499+
impl Borrow<[u8]> for HashSession {
500+
fn borrow(&self) -> &[u8] {
501+
self.0.id()
502+
}
503+
}
504+
505+
pub struct SessionCache {
506+
sessions: HashMap<SessionKey, LinkedHashSet<HashSession>>,
507+
reverse: HashMap<HashSession, SessionKey>,
508+
}
509+
510+
impl SessionCache {
511+
pub fn new() -> SessionCache {
512+
SessionCache {
513+
sessions: HashMap::new(),
514+
reverse: HashMap::new(),
515+
}
516+
}
517+
518+
pub fn insert(&mut self, key: SessionKey, session: SslSession) {
519+
let session = HashSession(session);
520+
521+
self.sessions
522+
.entry(key.clone())
523+
.or_insert_with(LinkedHashSet::new)
524+
.insert(session.clone());
525+
self.reverse.insert(session.clone(), key);
526+
}
527+
528+
pub fn get(&mut self, key: &SessionKey) -> Option<SslSession> {
529+
let session = {
530+
let sessions = self.sessions.get_mut(key)?;
531+
sessions.front().cloned()?.0
532+
};
533+
534+
#[cfg(ossl111)]
535+
{
536+
use self::openssl::ssl::SslVersion;
537+
538+
// https://tools.ietf.org/html/rfc8446#appendix-C.4
539+
// OpenSSL will remove the session from its cache after the handshake completes anyway, but this ensures
540+
// that concurrent handshakes don't end up with the same session.
541+
if session.protocol_version() == SslVersion::TLS1_3 {
542+
self.remove(&session);
543+
}
544+
}
545+
546+
Some(session)
547+
}
548+
549+
pub fn remove(&mut self, session: &SslSessionRef) {
550+
let key = match self.reverse.remove(session.id()) {
551+
Some(key) => key,
552+
None => return,
553+
};
554+
555+
if let Entry::Occupied(mut sessions) = self.sessions.entry(key) {
556+
sessions.get_mut().remove(session.id());
557+
if sessions.get().is_empty() {
558+
sessions.remove();
559+
}
560+
}
561+
}
562+
}
563+
564+
#[cfg(test)]
565+
mod tests {
566+
use std::io::{Read, Write};
567+
use std::net::TcpStream;
568+
569+
use crate::TlsConnector;
570+
571+
fn connect_and_assert(tls: &TlsConnector, domain: &str, port: u16, should_resume: bool) {
572+
let s = TcpStream::connect((domain, port)).unwrap();
573+
let mut stream = tls.connect(domain, s).unwrap();
574+
575+
// Must write to the stream, as OpenSSL doesn't appear to call the
576+
// session callback until we do.
577+
stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
578+
let mut result = vec![];
579+
stream.read_to_end(&mut result).unwrap();
580+
581+
assert_eq!((stream.0).0.ssl().session_reused(), should_resume);
582+
583+
// Must shut down properly, or OpenSSL will invalidate the session.
584+
stream.shutdown().unwrap();
585+
}
586+
587+
#[test]
588+
fn connect_no_session_ticket_resumption() {
589+
let tls = TlsConnector::new().unwrap();
590+
connect_and_assert(&tls, "google.com", 443, false);
591+
connect_and_assert(&tls, "google.com", 443, false);
592+
}
593+
594+
#[test]
595+
fn connect_session_ticket_resumption() {
596+
let mut builder = TlsConnector::builder();
597+
builder.session_tickets_enabled(true);
598+
let tls = builder.build().unwrap();
599+
600+
connect_and_assert(&tls, "google.com", 443, false);
601+
connect_and_assert(&tls, "google.com", 443, true);
602+
}
603+
604+
#[test]
605+
fn connect_session_ticket_resumption_two_sites() {
606+
let mut builder = TlsConnector::builder();
607+
builder.session_tickets_enabled(true);
608+
let tls = builder.build().unwrap();
609+
610+
connect_and_assert(&tls, "google.com", 443, false);
611+
connect_and_assert(&tls, "mozilla.org", 443, false);
612+
connect_and_assert(&tls, "google.com", 443, true);
613+
connect_and_assert(&tls, "mozilla.org", 443, true);
614+
}
615+
}

0 commit comments

Comments
 (0)