Skip to content

Commit 1f52ab4

Browse files
committed
Support RFC 5077 TLS session ticket reuse
1 parent 41522da commit 1f52ab4

File tree

6 files changed

+353
-14
lines changed

6 files changed

+353
-14
lines changed

Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@ vendored = ["openssl/vendored"]
1616
alpn = ["security-framework/alpn"]
1717

1818
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
19-
security-framework = "2.0.0"
19+
security-framework = { version = "2.0.0", features = ["session-tickets"] }
2020
security-framework-sys = "2.0.0"
2121
lazy_static = "1.4.0"
2222
libc = "0.2"
2323
tempfile = "3.1.0"
2424

2525
[target.'cfg(target_os = "windows")'.dependencies]
26-
schannel = "0.1.16"
26+
schannel = "0.1.18"
2727

2828
[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies]
29+
linked_hash_set = "0.1"
2930
log = "0.4.5"
31+
once_cell = "1.0"
3032
openssl = "0.10.29"
3133
openssl-sys = "0.9.55"
3234
openssl-probe = "0.1"

build.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ fn main() {
77
if version >= 0x1_01_00_00_0 {
88
println!("cargo:rustc-cfg=have_min_max_version");
99
}
10+
if version >= 0x1_01_01_00_0 {
11+
println!("cargo:rustc-cfg=ossl111");
12+
}
1013
}
1114

1215
if let Ok(version) = env::var("DEP_OPENSSL_LIBRESSL_VERSION_NUMBER") {

src/imp/openssl.rs

Lines changed: 204 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
1+
extern crate linked_hash_set;
2+
extern crate once_cell;
13
extern crate openssl;
24
extern crate openssl_probe;
35

6+
use self::linked_hash_set::LinkedHashSet;
7+
use self::once_cell::sync::OnceCell;
48
use self::openssl::error::ErrorStack;
9+
use self::openssl::ex_data::Index;
510
use self::openssl::hash::MessageDigest;
611
use self::openssl::nid::Nid;
712
use self::openssl::pkcs12::Pkcs12;
813
use self::openssl::pkey::PKey;
914
use self::openssl::ssl::{
10-
self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
11-
SslVerifyMode,
15+
self, MidHandshakeSslStream, Ssl, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
16+
SslSession, SslSessionCacheMode, SslSessionRef, SslVerifyMode,
1217
};
1318
use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
19+
use std::borrow::Borrow;
20+
use std::collections::hash_map::{Entry, HashMap};
1421
use std::error;
1522
use std::fmt;
23+
use std::hash::{Hash, Hasher};
1624
use std::io;
17-
use std::sync::Once;
25+
use std::sync::{Arc, Mutex, Once};
1826

1927
use self::openssl::pkey::Private;
2028
use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
@@ -248,6 +256,8 @@ pub struct TlsConnector {
248256
use_sni: bool,
249257
accept_invalid_hostnames: bool,
250258
accept_invalid_certs: bool,
259+
session_tickets_enabled: bool,
260+
session_cache: Arc<Mutex<SessionCache>>,
251261
}
252262

253263
impl TlsConnector {
@@ -297,11 +307,37 @@ impl TlsConnector {
297307
#[cfg(target_os = "android")]
298308
load_android_root_certs(&mut connector)?;
299309

310+
let session_cache = Arc::new(Mutex::new(SessionCache::new()));
311+
if builder.session_tickets_enabled {
312+
connector.set_session_cache_mode(SslSessionCacheMode::CLIENT);
313+
314+
connector.set_new_session_callback({
315+
let session_cache = session_cache.clone();
316+
move |ssl, session| {
317+
if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
318+
if let Ok(mut session_cache) = session_cache.lock() {
319+
session_cache.insert(key.clone(), session);
320+
}
321+
}
322+
}
323+
});
324+
connector.set_remove_session_callback({
325+
let session_cache = session_cache.clone();
326+
move |_, session| {
327+
if let Ok(mut session_cache) = session_cache.lock() {
328+
session_cache.remove(session);
329+
}
330+
}
331+
});
332+
}
333+
300334
Ok(TlsConnector {
301335
connector: connector.build(),
302336
use_sni: builder.use_sni,
303337
accept_invalid_hostnames: builder.accept_invalid_hostnames,
304338
accept_invalid_certs: builder.accept_invalid_certs,
339+
session_tickets_enabled: builder.session_tickets_enabled,
340+
session_cache,
305341
})
306342
}
307343

@@ -317,6 +353,23 @@ impl TlsConnector {
317353
if self.accept_invalid_certs {
318354
ssl.set_verify(SslVerifyMode::NONE);
319355
}
356+
if self.session_tickets_enabled {
357+
let key = SessionKey {
358+
host: domain.to_string(),
359+
};
360+
361+
if let Ok(mut session_cache) = self.session_cache.lock() {
362+
if let Some(session) = session_cache.get(&key) {
363+
// Note: the `unsafe`-ty here is because the `session` is required to come from the
364+
// same SSL_CTX that the ssl object (`ssl`) is from, since it maintains internal
365+
// pointers and refcounts. Here, we only have one SSL_CTX, so this is safe.
366+
unsafe { ssl.set_session(&session)? };
367+
}
368+
}
369+
370+
let idx = key_index()?;
371+
ssl.set_ex_data(idx, key);
372+
}
320373

321374
let s = ssl.connect(domain, stream)?;
322375
Ok(TlsStream(s))
@@ -452,3 +505,151 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
452505
self.0.flush()
453506
}
454507
}
508+
509+
fn key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack> {
510+
static IDX: OnceCell<Index<Ssl, SessionKey>> = OnceCell::new();
511+
IDX.get_or_try_init(|| Ssl::new_ex_index()).map(|v| *v)
512+
}
513+
514+
#[derive(Hash, PartialEq, Eq, Clone)]
515+
pub struct SessionKey {
516+
pub host: String,
517+
}
518+
519+
#[derive(Clone)]
520+
struct HashSession(SslSession);
521+
522+
impl PartialEq for HashSession {
523+
fn eq(&self, other: &HashSession) -> bool {
524+
self.0.id() == other.0.id()
525+
}
526+
}
527+
528+
impl Eq for HashSession {}
529+
530+
impl Hash for HashSession {
531+
fn hash<H>(&self, state: &mut H)
532+
where
533+
H: Hasher,
534+
{
535+
self.0.id().hash(state);
536+
}
537+
}
538+
539+
impl Borrow<[u8]> for HashSession {
540+
fn borrow(&self) -> &[u8] {
541+
self.0.id()
542+
}
543+
}
544+
545+
pub struct SessionCache {
546+
sessions: HashMap<SessionKey, LinkedHashSet<HashSession>>,
547+
reverse: HashMap<HashSession, SessionKey>,
548+
}
549+
550+
impl SessionCache {
551+
pub fn new() -> SessionCache {
552+
SessionCache {
553+
sessions: HashMap::new(),
554+
reverse: HashMap::new(),
555+
}
556+
}
557+
558+
pub fn insert(&mut self, key: SessionKey, session: SslSession) {
559+
let session = HashSession(session);
560+
561+
self.sessions
562+
.entry(key.clone())
563+
.or_insert_with(LinkedHashSet::new)
564+
.insert(session.clone());
565+
self.reverse.insert(session.clone(), key);
566+
}
567+
568+
pub fn get(&mut self, key: &SessionKey) -> Option<SslSession> {
569+
let session = {
570+
let sessions = self.sessions.get_mut(key)?;
571+
sessions.front().cloned()?.0
572+
};
573+
574+
#[cfg(ossl111)]
575+
{
576+
use self::openssl::ssl::SslVersion;
577+
578+
// https://tools.ietf.org/html/rfc8446#appendix-C.4
579+
// OpenSSL will remove the session from its cache after the handshake completes anyway, but this ensures
580+
// that concurrent handshakes don't end up with the same session.
581+
if session.protocol_version() == SslVersion::TLS1_3 {
582+
self.remove(&session);
583+
}
584+
}
585+
586+
Some(session)
587+
}
588+
589+
pub fn remove(&mut self, session: &SslSessionRef) {
590+
let key = match self.reverse.remove(session.id()) {
591+
Some(key) => key,
592+
None => return,
593+
};
594+
595+
if let Entry::Occupied(mut sessions) = self.sessions.entry(key) {
596+
sessions.get_mut().remove(session.id());
597+
if sessions.get().is_empty() {
598+
sessions.remove();
599+
}
600+
}
601+
}
602+
}
603+
604+
#[cfg(test)]
605+
mod tests {
606+
use std::io::{Read, Write};
607+
use std::net::TcpStream;
608+
609+
use crate::TlsConnector;
610+
611+
fn connect_and_assert(tls: &TlsConnector, domain: &str, port: u16, should_resume: bool) {
612+
let s = TcpStream::connect((domain, port)).unwrap();
613+
let mut stream = tls.connect(domain, s).unwrap();
614+
615+
// Must write to the stream, as OpenSSL doesn't appear to call the
616+
// session callback until we do.
617+
stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
618+
let mut result = vec![];
619+
stream.read_to_end(&mut result).unwrap();
620+
621+
assert_eq!((stream.0).0.ssl().session_reused(), should_resume);
622+
623+
// Must shut down properly, or OpenSSL will invalidate the session.
624+
stream.shutdown().unwrap();
625+
}
626+
627+
#[test]
628+
fn connect_no_session_ticket_resumption() {
629+
let tls = TlsConnector::new().unwrap();
630+
connect_and_assert(&tls, "google.com", 443, false);
631+
connect_and_assert(&tls, "google.com", 443, false);
632+
}
633+
634+
#[test]
635+
fn connect_session_ticket_resumption() {
636+
let mut builder = TlsConnector::builder();
637+
builder.session_tickets_enabled(true);
638+
let tls = builder.build().unwrap();
639+
640+
connect_and_assert(&tls, "google.com", 443, false);
641+
connect_and_assert(&tls, "google.com", 443, true);
642+
}
643+
644+
#[test]
645+
fn connect_session_ticket_resumption_two_sites() {
646+
let mut builder = TlsConnector::builder();
647+
builder.session_tickets_enabled(true);
648+
let tls = builder.build().unwrap();
649+
650+
connect_and_assert(&tls, "google.com", 443, false);
651+
connect_and_assert(&tls, "mozilla.org", 443, false);
652+
connect_and_assert(&tls, "google.com", 443, true);
653+
connect_and_assert(&tls, "mozilla.org", 443, true);
654+
}
655+
}

0 commit comments

Comments
 (0)