1
+ extern crate linked_hash_set;
2
+ extern crate once_cell;
1
3
extern crate openssl;
2
4
extern crate openssl_probe;
3
5
6
+ use self :: linked_hash_set:: LinkedHashSet ;
7
+ use self :: once_cell:: sync:: OnceCell ;
4
8
use self :: openssl:: error:: ErrorStack ;
9
+ use self :: openssl:: ex_data:: Index ;
5
10
use self :: openssl:: hash:: MessageDigest ;
6
11
use self :: openssl:: nid:: Nid ;
7
12
use self :: openssl:: pkcs12:: Pkcs12 ;
8
13
use self :: openssl:: pkey:: PKey ;
9
14
use self :: openssl:: ssl:: {
10
- self , MidHandshakeSslStream , SslAcceptor , SslConnector , SslContextBuilder , SslMethod ,
11
- SslVerifyMode ,
15
+ self , MidHandshakeSslStream , Ssl , SslAcceptor , SslConnector , SslContextBuilder , SslMethod ,
16
+ SslSession , SslSessionCacheMode , SslSessionRef , SslVerifyMode ,
12
17
} ;
13
18
use self :: openssl:: x509:: { store:: X509StoreBuilder , X509VerifyResult , X509 } ;
19
+ use std:: borrow:: Borrow ;
20
+ use std:: collections:: hash_map:: { Entry , HashMap } ;
14
21
use std:: error;
15
22
use std:: fmt;
23
+ use std:: hash:: { Hash , Hasher } ;
16
24
use std:: io;
17
- use std:: sync:: Once ;
25
+ use std:: sync:: { Arc , Mutex , Once } ;
18
26
19
27
use self :: openssl:: pkey:: Private ;
20
28
use { Protocol , TlsAcceptorBuilder , TlsConnectorBuilder } ;
@@ -248,6 +256,8 @@ pub struct TlsConnector {
248
256
use_sni : bool ,
249
257
accept_invalid_hostnames : bool ,
250
258
accept_invalid_certs : bool ,
259
+ session_tickets_enabled : bool ,
260
+ session_cache : Arc < Mutex < SessionCache > > ,
251
261
}
252
262
253
263
impl TlsConnector {
@@ -297,11 +307,37 @@ impl TlsConnector {
297
307
#[ cfg( target_os = "android" ) ]
298
308
load_android_root_certs ( & mut connector) ?;
299
309
310
+ let session_cache = Arc :: new ( Mutex :: new ( SessionCache :: new ( ) ) ) ;
311
+ if builder. session_tickets_enabled {
312
+ connector. set_session_cache_mode ( SslSessionCacheMode :: CLIENT ) ;
313
+
314
+ connector. set_new_session_callback ( {
315
+ let session_cache = session_cache. clone ( ) ;
316
+ move |ssl, session| {
317
+ if let Some ( key) = key_index ( ) . ok ( ) . and_then ( |idx| ssl. ex_data ( idx) ) {
318
+ if let Ok ( mut session_cache) = session_cache. lock ( ) {
319
+ session_cache. insert ( key. clone ( ) , session) ;
320
+ }
321
+ }
322
+ }
323
+ } ) ;
324
+ connector. set_remove_session_callback ( {
325
+ let session_cache = session_cache. clone ( ) ;
326
+ move |_, session| {
327
+ if let Ok ( mut session_cache) = session_cache. lock ( ) {
328
+ session_cache. remove ( session) ;
329
+ }
330
+ }
331
+ } ) ;
332
+ }
333
+
300
334
Ok ( TlsConnector {
301
335
connector : connector. build ( ) ,
302
336
use_sni : builder. use_sni ,
303
337
accept_invalid_hostnames : builder. accept_invalid_hostnames ,
304
338
accept_invalid_certs : builder. accept_invalid_certs ,
339
+ session_tickets_enabled : builder. session_tickets_enabled ,
340
+ session_cache,
305
341
} )
306
342
}
307
343
@@ -317,6 +353,23 @@ impl TlsConnector {
317
353
if self . accept_invalid_certs {
318
354
ssl. set_verify ( SslVerifyMode :: NONE ) ;
319
355
}
356
+ if self . session_tickets_enabled {
357
+ let key = SessionKey {
358
+ host : domain. to_string ( ) ,
359
+ } ;
360
+
361
+ if let Ok ( mut session_cache) = self . session_cache . lock ( ) {
362
+ if let Some ( session) = session_cache. get ( & key) {
363
+ // Note: the `unsafe`-ty here is because the `session` is required to come from the
364
+ // same SSL_CTX that the ssl object (`ssl`) is from, since it maintains internal
365
+ // pointers and refcounts. Here, we only have one SSL_CTX, so this is safe.
366
+ unsafe { ssl. set_session ( & session) ? } ;
367
+ }
368
+ }
369
+
370
+ let idx = key_index ( ) ?;
371
+ ssl. set_ex_data ( idx, key) ;
372
+ }
320
373
321
374
let s = ssl. connect ( domain, stream) ?;
322
375
Ok ( TlsStream ( s) )
@@ -452,3 +505,151 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
452
505
self . 0 . flush ( )
453
506
}
454
507
}
508
+
509
+ fn key_index ( ) -> Result < Index < Ssl , SessionKey > , ErrorStack > {
510
+ static IDX : OnceCell < Index < Ssl , SessionKey > > = OnceCell :: new ( ) ;
511
+ IDX . get_or_try_init ( || Ssl :: new_ex_index ( ) ) . map ( |v| * v)
512
+ }
513
+
514
+ #[ derive( Hash , PartialEq , Eq , Clone ) ]
515
+ pub struct SessionKey {
516
+ pub host : String ,
517
+ }
518
+
519
+ #[ derive( Clone ) ]
520
+ struct HashSession ( SslSession ) ;
521
+
522
+ impl PartialEq for HashSession {
523
+ fn eq ( & self , other : & HashSession ) -> bool {
524
+ self . 0 . id ( ) == other. 0 . id ( )
525
+ }
526
+ }
527
+
528
+ impl Eq for HashSession { }
529
+
530
+ impl Hash for HashSession {
531
+ fn hash < H > ( & self , state : & mut H )
532
+ where
533
+ H : Hasher ,
534
+ {
535
+ self . 0 . id ( ) . hash ( state) ;
536
+ }
537
+ }
538
+
539
+ impl Borrow < [ u8 ] > for HashSession {
540
+ fn borrow ( & self ) -> & [ u8 ] {
541
+ self . 0 . id ( )
542
+ }
543
+ }
544
+
545
+ pub struct SessionCache {
546
+ sessions : HashMap < SessionKey , LinkedHashSet < HashSession > > ,
547
+ reverse : HashMap < HashSession , SessionKey > ,
548
+ }
549
+
550
+ impl SessionCache {
551
+ pub fn new ( ) -> SessionCache {
552
+ SessionCache {
553
+ sessions : HashMap :: new ( ) ,
554
+ reverse : HashMap :: new ( ) ,
555
+ }
556
+ }
557
+
558
+ pub fn insert ( & mut self , key : SessionKey , session : SslSession ) {
559
+ let session = HashSession ( session) ;
560
+
561
+ self . sessions
562
+ . entry ( key. clone ( ) )
563
+ . or_insert_with ( LinkedHashSet :: new)
564
+ . insert ( session. clone ( ) ) ;
565
+ self . reverse . insert ( session. clone ( ) , key) ;
566
+ }
567
+
568
+ pub fn get ( & mut self , key : & SessionKey ) -> Option < SslSession > {
569
+ let session = {
570
+ let sessions = self . sessions . get_mut ( key) ?;
571
+ sessions. front ( ) . cloned ( ) ?. 0
572
+ } ;
573
+
574
+ #[ cfg( ossl111) ]
575
+ {
576
+ use self :: openssl:: ssl:: SslVersion ;
577
+
578
+ // https://tools.ietf.org/html/rfc8446#appendix-C.4
579
+ // OpenSSL will remove the session from its cache after the handshake completes anyway, but this ensures
580
+ // that concurrent handshakes don't end up with the same session.
581
+ if session. protocol_version ( ) == SslVersion :: TLS1_3 {
582
+ self . remove ( & session) ;
583
+ }
584
+ }
585
+
586
+ Some ( session)
587
+ }
588
+
589
+ pub fn remove ( & mut self , session : & SslSessionRef ) {
590
+ let key = match self . reverse . remove ( session. id ( ) ) {
591
+ Some ( key) => key,
592
+ None => return ,
593
+ } ;
594
+
595
+ if let Entry :: Occupied ( mut sessions) = self . sessions . entry ( key) {
596
+ sessions. get_mut ( ) . remove ( session. id ( ) ) ;
597
+ if sessions. get ( ) . is_empty ( ) {
598
+ sessions. remove ( ) ;
599
+ }
600
+ }
601
+ }
602
+ }
603
+
604
+ #[ cfg( test) ]
605
+ mod tests {
606
+ use std:: io:: { Read , Write } ;
607
+ use std:: net:: TcpStream ;
608
+
609
+ use crate :: TlsConnector ;
610
+
611
+ fn connect_and_assert ( tls : & TlsConnector , domain : & str , port : u16 , should_resume : bool ) {
612
+ let s = TcpStream :: connect ( ( domain, port) ) . unwrap ( ) ;
613
+ let mut stream = tls. connect ( domain, s) . unwrap ( ) ;
614
+
615
+ // Must write to the stream, as OpenSSL doesn't appear to call the
616
+ // session callback until we do.
617
+ stream. write_all ( b"GET / HTTP/1.0\r \n \r \n " ) . unwrap ( ) ;
618
+ let mut result = vec ! [ ] ;
619
+ stream. read_to_end ( & mut result) . unwrap ( ) ;
620
+
621
+ assert_eq ! ( ( stream. 0 ) . 0 . ssl( ) . session_reused( ) , should_resume) ;
622
+
623
+ // Must shut down properly, or OpenSSL will invalidate the session.
624
+ stream. shutdown ( ) . unwrap ( ) ;
625
+ }
626
+
627
+ #[ test]
628
+ fn connect_no_session_ticket_resumption ( ) {
629
+ let tls = TlsConnector :: new ( ) . unwrap ( ) ;
630
+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
631
+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
632
+ }
633
+
634
+ #[ test]
635
+ fn connect_session_ticket_resumption ( ) {
636
+ let mut builder = TlsConnector :: builder ( ) ;
637
+ builder. session_tickets_enabled ( true ) ;
638
+ let tls = builder. build ( ) . unwrap ( ) ;
639
+
640
+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
641
+ connect_and_assert ( & tls, "google.com" , 443 , true ) ;
642
+ }
643
+
644
+ #[ test]
645
+ fn connect_session_ticket_resumption_two_sites ( ) {
646
+ let mut builder = TlsConnector :: builder ( ) ;
647
+ builder. session_tickets_enabled ( true ) ;
648
+ let tls = builder. build ( ) . unwrap ( ) ;
649
+
650
+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
651
+ connect_and_assert ( & tls, "mozilla.org" , 443 , false ) ;
652
+ connect_and_assert ( & tls, "google.com" , 443 , true ) ;
653
+ connect_and_assert ( & tls, "mozilla.org" , 443 , true ) ;
654
+ }
655
+ }
0 commit comments