1
+ extern crate linked_hash_set;
2
+ extern crate once_cell;
1
3
extern crate openssl;
2
4
extern crate openssl_probe;
3
5
6
+ use self :: linked_hash_set:: LinkedHashSet ;
7
+ use self :: once_cell:: sync:: OnceCell ;
4
8
use self :: openssl:: error:: ErrorStack ;
9
+ use self :: openssl:: ex_data:: Index ;
5
10
use self :: openssl:: hash:: MessageDigest ;
6
11
use self :: openssl:: nid:: Nid ;
7
12
use self :: openssl:: pkcs12:: Pkcs12 ;
8
13
use self :: openssl:: pkey:: PKey ;
9
14
use self :: openssl:: ssl:: {
10
- self , MidHandshakeSslStream , SslAcceptor , SslConnector , SslContextBuilder , SslMethod ,
11
- SslVerifyMode ,
15
+ self , MidHandshakeSslStream , Ssl , SslAcceptor , SslConnector , SslContextBuilder , SslMethod ,
16
+ SslSession , SslSessionCacheMode , SslSessionRef , SslVerifyMode ,
12
17
} ;
13
18
use self :: openssl:: x509:: { X509 , store:: X509StoreBuilder , X509VerifyResult } ;
19
+ use std:: borrow:: Borrow ;
20
+ use std:: collections:: hash_map:: { Entry , HashMap } ;
14
21
use std:: error;
15
22
use std:: fmt;
23
+ use std:: hash:: { Hash , Hasher } ;
16
24
use std:: io;
17
- use std:: sync:: Once ;
25
+ use std:: sync:: { Arc , Mutex , Once } ;
18
26
19
27
use { Protocol , TlsAcceptorBuilder , TlsConnectorBuilder } ;
20
28
use self :: openssl:: pkey:: Private ;
@@ -248,6 +256,8 @@ pub struct TlsConnector {
248
256
use_sni : bool ,
249
257
accept_invalid_hostnames : bool ,
250
258
accept_invalid_certs : bool ,
259
+ session_tickets_enabled : bool ,
260
+ session_cache : Arc < Mutex < SessionCache > > ,
251
261
}
252
262
253
263
impl TlsConnector {
@@ -277,11 +287,37 @@ impl TlsConnector {
277
287
#[ cfg( target_os = "android" ) ]
278
288
load_android_root_certs ( & mut connector) ?;
279
289
290
+ let session_cache = Arc :: new ( Mutex :: new ( SessionCache :: new ( ) ) ) ;
291
+ if builder. session_tickets_enabled {
292
+ connector. set_session_cache_mode ( SslSessionCacheMode :: CLIENT ) ;
293
+
294
+ connector. set_new_session_callback ( {
295
+ let session_cache = session_cache. clone ( ) ;
296
+ move |ssl, session| {
297
+ if let Some ( key) = key_index ( ) . ok ( ) . and_then ( |idx| ssl. ex_data ( idx) ) {
298
+ if let Ok ( mut session_cache) = session_cache. lock ( ) {
299
+ session_cache. insert ( key. clone ( ) , session) ;
300
+ }
301
+ }
302
+ }
303
+ } ) ;
304
+ connector. set_remove_session_callback ( {
305
+ let session_cache = session_cache. clone ( ) ;
306
+ move |_, session| {
307
+ if let Ok ( mut session_cache) = session_cache. lock ( ) {
308
+ session_cache. remove ( session) ;
309
+ }
310
+ }
311
+ } ) ;
312
+ }
313
+
280
314
Ok ( TlsConnector {
281
315
connector : connector. build ( ) ,
282
316
use_sni : builder. use_sni ,
283
317
accept_invalid_hostnames : builder. accept_invalid_hostnames ,
284
318
accept_invalid_certs : builder. accept_invalid_certs ,
319
+ session_tickets_enabled : builder. session_tickets_enabled ,
320
+ session_cache,
285
321
} )
286
322
}
287
323
@@ -297,6 +333,23 @@ impl TlsConnector {
297
333
if self . accept_invalid_certs {
298
334
ssl. set_verify ( SslVerifyMode :: NONE ) ;
299
335
}
336
+ if self . session_tickets_enabled {
337
+ let key = SessionKey {
338
+ host : domain. to_string ( ) ,
339
+ } ;
340
+
341
+ if let Ok ( mut session_cache) = self . session_cache . lock ( ) {
342
+ if let Some ( session) = session_cache. get ( & key) {
343
+ // Note: the `unsafe`-ty here is because the `session` is required to come from the
344
+ // same SSL_CTX that the ssl object (`ssl`) is from, since it maintains internal
345
+ // pointers and refcounts. Here, we only have one SSL_CTX, so this is safe.
346
+ unsafe { ssl. set_session ( & session) ? } ;
347
+ }
348
+ }
349
+
350
+ let idx = key_index ( ) ?;
351
+ ssl. set_ex_data ( idx, key) ;
352
+ }
300
353
301
354
let s = ssl. connect ( domain, stream) ?;
302
355
Ok ( TlsStream ( s) )
@@ -412,3 +465,151 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
412
465
self . 0 . flush ( )
413
466
}
414
467
}
468
+
469
+ fn key_index ( ) -> Result < Index < Ssl , SessionKey > , ErrorStack > {
470
+ static IDX : OnceCell < Index < Ssl , SessionKey > > = OnceCell :: new ( ) ;
471
+ IDX . get_or_try_init ( || Ssl :: new_ex_index ( ) ) . map ( |v| * v)
472
+ }
473
+
474
+ #[ derive( Hash , PartialEq , Eq , Clone ) ]
475
+ pub struct SessionKey {
476
+ pub host : String ,
477
+ }
478
+
479
+ #[ derive( Clone ) ]
480
+ struct HashSession ( SslSession ) ;
481
+
482
+ impl PartialEq for HashSession {
483
+ fn eq ( & self , other : & HashSession ) -> bool {
484
+ self . 0 . id ( ) == other. 0 . id ( )
485
+ }
486
+ }
487
+
488
+ impl Eq for HashSession { }
489
+
490
+ impl Hash for HashSession {
491
+ fn hash < H > ( & self , state : & mut H )
492
+ where
493
+ H : Hasher ,
494
+ {
495
+ self . 0 . id ( ) . hash ( state) ;
496
+ }
497
+ }
498
+
499
+ impl Borrow < [ u8 ] > for HashSession {
500
+ fn borrow ( & self ) -> & [ u8 ] {
501
+ self . 0 . id ( )
502
+ }
503
+ }
504
+
505
+ pub struct SessionCache {
506
+ sessions : HashMap < SessionKey , LinkedHashSet < HashSession > > ,
507
+ reverse : HashMap < HashSession , SessionKey > ,
508
+ }
509
+
510
+ impl SessionCache {
511
+ pub fn new ( ) -> SessionCache {
512
+ SessionCache {
513
+ sessions : HashMap :: new ( ) ,
514
+ reverse : HashMap :: new ( ) ,
515
+ }
516
+ }
517
+
518
+ pub fn insert ( & mut self , key : SessionKey , session : SslSession ) {
519
+ let session = HashSession ( session) ;
520
+
521
+ self . sessions
522
+ . entry ( key. clone ( ) )
523
+ . or_insert_with ( LinkedHashSet :: new)
524
+ . insert ( session. clone ( ) ) ;
525
+ self . reverse . insert ( session. clone ( ) , key) ;
526
+ }
527
+
528
+ pub fn get ( & mut self , key : & SessionKey ) -> Option < SslSession > {
529
+ let session = {
530
+ let sessions = self . sessions . get_mut ( key) ?;
531
+ sessions. front ( ) . cloned ( ) ?. 0
532
+ } ;
533
+
534
+ #[ cfg( ossl111) ]
535
+ {
536
+ use self :: openssl:: ssl:: SslVersion ;
537
+
538
+ // https://tools.ietf.org/html/rfc8446#appendix-C.4
539
+ // OpenSSL will remove the session from its cache after the handshake completes anyway, but this ensures
540
+ // that concurrent handshakes don't end up with the same session.
541
+ if session. protocol_version ( ) == SslVersion :: TLS1_3 {
542
+ self . remove ( & session) ;
543
+ }
544
+ }
545
+
546
+ Some ( session)
547
+ }
548
+
549
+ pub fn remove ( & mut self , session : & SslSessionRef ) {
550
+ let key = match self . reverse . remove ( session. id ( ) ) {
551
+ Some ( key) => key,
552
+ None => return ,
553
+ } ;
554
+
555
+ if let Entry :: Occupied ( mut sessions) = self . sessions . entry ( key) {
556
+ sessions. get_mut ( ) . remove ( session. id ( ) ) ;
557
+ if sessions. get ( ) . is_empty ( ) {
558
+ sessions. remove ( ) ;
559
+ }
560
+ }
561
+ }
562
+ }
563
+
564
+ #[ cfg( test) ]
565
+ mod tests {
566
+ use std:: io:: { Read , Write } ;
567
+ use std:: net:: TcpStream ;
568
+
569
+ use crate :: TlsConnector ;
570
+
571
+ fn connect_and_assert ( tls : & TlsConnector , domain : & str , port : u16 , should_resume : bool ) {
572
+ let s = TcpStream :: connect ( ( domain, port) ) . unwrap ( ) ;
573
+ let mut stream = tls. connect ( domain, s) . unwrap ( ) ;
574
+
575
+ // Must write to the stream, as OpenSSL doesn't appear to call the
576
+ // session callback until we do.
577
+ stream. write_all ( b"GET / HTTP/1.0\r \n \r \n " ) . unwrap ( ) ;
578
+ let mut result = vec ! [ ] ;
579
+ stream. read_to_end ( & mut result) . unwrap ( ) ;
580
+
581
+ assert_eq ! ( ( stream. 0 ) . 0 . ssl( ) . session_reused( ) , should_resume) ;
582
+
583
+ // Must shut down properly, or OpenSSL will invalidate the session.
584
+ stream. shutdown ( ) . unwrap ( ) ;
585
+ }
586
+
587
+ #[ test]
588
+ fn connect_no_session_ticket_resumption ( ) {
589
+ let tls = TlsConnector :: new ( ) . unwrap ( ) ;
590
+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
591
+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
592
+ }
593
+
594
+ #[ test]
595
+ fn connect_session_ticket_resumption ( ) {
596
+ let mut builder = TlsConnector :: builder ( ) ;
597
+ builder. session_tickets_enabled ( true ) ;
598
+ let tls = builder. build ( ) . unwrap ( ) ;
599
+
600
+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
601
+ connect_and_assert ( & tls, "google.com" , 443 , true ) ;
602
+ }
603
+
604
+ #[ test]
605
+ fn connect_session_ticket_resumption_two_sites ( ) {
606
+ let mut builder = TlsConnector :: builder ( ) ;
607
+ builder. session_tickets_enabled ( true ) ;
608
+ let tls = builder. build ( ) . unwrap ( ) ;
609
+
610
+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
611
+ connect_and_assert ( & tls, "mozilla.org" , 443 , false ) ;
612
+ connect_and_assert ( & tls, "google.com" , 443 , true ) ;
613
+ connect_and_assert ( & tls, "mozilla.org" , 443 , true ) ;
614
+ }
615
+ }
0 commit comments