@@ -20,24 +20,21 @@ use {
20
20
std:: marker:: Unpin ,
21
21
std:: pin:: Pin ,
22
22
std:: task:: { Context as TaskContext , Poll } ,
23
+ tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf } ,
24
+ crate :: ssl:: async_utils:: IoAdapter ,
23
25
} ;
24
26
25
27
use mbedtls_sys:: types:: raw_types:: { c_int, c_uchar, c_void} ;
26
28
use mbedtls_sys:: types:: size_t;
27
29
use mbedtls_sys:: * ;
28
30
29
- #[ cfg( all( feature = "std" , feature = "async" ) ) ]
30
- use tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf } ;
31
-
32
31
#[ cfg( not( feature = "std" ) ) ]
33
32
use crate :: alloc_prelude:: * ;
34
33
use crate :: alloc:: { List as MbedtlsList } ;
35
34
use crate :: error:: { Error , Result , IntoResult } ;
36
35
use crate :: pk:: Pk ;
37
36
use crate :: private:: UnsafeFrom ;
38
37
use crate :: ssl:: config:: { Config , Version , AuthMode } ;
39
- #[ cfg( all( feature = "std" , feature = "async" ) ) ]
40
- use crate :: ssl:: async_utils:: IoAdapter ;
41
38
use crate :: x509:: { Certificate , Crl , VerifyError } ;
42
39
43
40
pub trait IoCallback {
@@ -199,7 +196,7 @@ define!(
199
196
struct HandshakeContext {
200
197
handshake_ca_cert: Option <Arc <MbedtlsList <Certificate >>>,
201
198
handshake_crl: Option <Arc <Crl >>,
202
-
199
+
203
200
handshake_cert: Vec <Arc <MbedtlsList <Certificate >>>,
204
201
handshake_pk: Vec <Arc <Pk >>,
205
202
} ;
@@ -213,10 +210,10 @@ define!(
213
210
pub struct Context < T > {
214
211
// Base structure used in SNI callback where we cannot determine the io type.
215
212
inner : HandshakeContext ,
216
-
213
+
217
214
// config is used read-only for multiple contexts and is immutable once configured.
218
- config : Arc < Config > ,
219
-
215
+ config : Arc < Config > ,
216
+
220
217
// Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated.
221
218
io : Option < Box < T > > ,
222
219
@@ -240,14 +237,10 @@ impl<'a, T> Into<*mut ssl_context> for &'a mut Context<T> {
240
237
}
241
238
}
242
239
243
- #[ cfg( all( feature = "std" , feature = "async" ) ) ]
244
- pub type AsyncContext < T > = Context < IoAdapter < T > > ;
245
-
246
-
247
240
impl < T > Context < T > {
248
241
pub fn new ( config : Arc < Config > ) -> Self {
249
242
let mut inner = ssl_context:: default ( ) ;
250
-
243
+
251
244
unsafe {
252
245
ssl_init ( & mut inner) ;
253
246
ssl_setup ( & mut inner, ( & * config) . into ( ) ) ;
@@ -258,7 +251,7 @@ impl<T> Context<T> {
258
251
inner,
259
252
handshake_ca_cert : None ,
260
253
handshake_crl : None ,
261
-
254
+
262
255
handshake_cert : vec ! [ ] ,
263
256
handshake_pk : vec ! [ ] ,
264
257
} ,
@@ -268,11 +261,11 @@ impl<T> Context<T> {
268
261
client_transport_id : None ,
269
262
}
270
263
}
271
-
264
+
272
265
pub ( crate ) fn handle ( & self ) -> & :: mbedtls_sys:: ssl_context {
273
266
self . inner . handle ( )
274
267
}
275
-
268
+
276
269
pub ( crate ) fn handle_mut ( & mut self ) -> & mut :: mbedtls_sys:: ssl_context {
277
270
self . inner . handle_mut ( )
278
271
}
@@ -385,23 +378,23 @@ impl<T> Context<T> {
385
378
pub fn config ( & self ) -> & Arc < Config > {
386
379
& self . config
387
380
}
388
-
381
+
389
382
pub fn close ( & mut self ) {
390
383
unsafe {
391
384
ssl_close_notify ( self . into ( ) ) ;
392
385
ssl_set_bio ( self . into ( ) , :: core:: ptr:: null_mut ( ) , None , None , None ) ;
393
386
self . io = None ;
394
387
}
395
388
}
396
-
389
+
397
390
pub fn io ( & self ) -> Option < & T > {
398
391
self . io . as_ref ( ) . map ( |v| & * * v)
399
392
}
400
-
393
+
401
394
pub fn io_mut ( & mut self ) -> Option < & mut T > {
402
395
self . io . as_mut ( ) . map ( |v| & mut * * v)
403
396
}
404
-
397
+
405
398
/// Return the minor number of the negotiated TLS version
406
399
pub fn minor_version ( & self ) -> i32 {
407
400
self . handle ( ) . minor_ver
@@ -433,15 +426,15 @@ impl<T> Context<T> {
433
426
434
427
435
428
// Session specific functions
436
-
429
+
437
430
/// Return the 16-bit ciphersuite identifier.
438
431
/// All assigned ciphersuites are listed by the IANA in
439
432
/// <https://www.iana.org/assignments/tls-parameters/tls-parameters.txt>
440
433
pub fn ciphersuite ( & self ) -> Result < u16 > {
441
434
if self . handle ( ) . session . is_null ( ) {
442
435
return Err ( Error :: SslBadInputData ) ;
443
436
}
444
-
437
+
445
438
Ok ( unsafe { self . handle ( ) . session . as_ref ( ) . unwrap ( ) . ciphersuite as u16 } )
446
439
}
447
440
@@ -578,12 +571,12 @@ impl HandshakeContext {
578
571
self . handshake_ca_cert = None ;
579
572
self . handshake_crl = None ;
580
573
}
581
-
574
+
582
575
pub fn set_authmode ( & mut self , am : AuthMode ) -> Result < ( ) > {
583
576
if self . inner . handshake as * const _ == :: core:: ptr:: null ( ) {
584
577
return Err ( Error :: SslBadInputData ) ;
585
578
}
586
-
579
+
587
580
unsafe { ssl_set_hs_authmode ( self . into ( ) , am as i32 ) }
588
581
Ok ( ( ) )
589
582
}
@@ -637,6 +630,9 @@ impl HandshakeContext {
637
630
}
638
631
}
639
632
633
+ #[ cfg( all( feature = "std" , feature = "async" ) ) ]
634
+ pub type AsyncContext < T > = Context < IoAdapter < T > > ;
635
+
640
636
#[ cfg( all( feature = "std" , feature = "async" ) ) ]
641
637
pub trait IoAsyncCallback {
642
638
unsafe extern "C" fn call_recv_async ( user_data : * mut c_void , data : * mut c_uchar , len : size_t ) -> c_int where Self : Sized ;
@@ -700,7 +696,7 @@ impl<T> std::future::Future for HandshakeFuture<'_, T> {
700
696
fn poll ( mut self : Pin < & mut Self > , ctx : & mut TaskContext ) -> std:: task:: Poll < Self :: Output > {
701
697
self . 0 . io_mut ( ) . ok_or ( Error :: NetInvalidContext ) ?
702
698
. ecx . set ( ctx) ;
703
-
699
+
704
700
let result = match self . 0 . handshake ( ) {
705
701
Err ( Error :: SslWantRead ) |
706
702
Err ( Error :: SslWantWrite ) => {
@@ -709,9 +705,9 @@ impl<T> std::future::Future for HandshakeFuture<'_, T> {
709
705
Err ( e) => Poll :: Ready ( Err ( e) ) ,
710
706
Ok ( ( ) ) => Poll :: Ready ( Ok ( ( ) ) )
711
707
} ;
712
-
708
+
713
709
self . 0 . io_mut ( ) . map ( |v| v. ecx . clear ( ) ) ;
714
-
710
+
715
711
result
716
712
}
717
713
}
@@ -741,7 +737,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncContext<T> {
741
737
) ;
742
738
743
739
self . io = Some ( io) ;
744
- self . inner . reset_handshake ( ) ;
740
+ self . inner . reset_handshake ( ) ;
745
741
}
746
742
747
743
HandshakeFuture ( self ) . await
@@ -762,7 +758,7 @@ impl<T: AsyncRead> AsyncRead for Context<IoAdapter<T>> {
762
758
763
759
self . io_mut ( ) . ok_or ( IoError :: new ( IoErrorKind :: Other , "stream has been shutdown" ) ) ?
764
760
. ecx . set ( cx) ;
765
-
761
+
766
762
let result = match unsafe { ssl_read ( ( & mut * self ) . into ( ) , buf. initialize_unfilled ( ) . as_mut_ptr ( ) , buf. initialize_unfilled ( ) . len ( ) ) . into_result ( ) } {
767
763
Err ( Error :: SslPeerCloseNotify ) => Poll :: Ready ( Ok ( ( ) ) ) ,
768
764
Err ( Error :: SslWantRead ) => Poll :: Pending ,
@@ -798,10 +794,10 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for Context<IoAdapter<T>> {
798
794
io. write_tracker . adjust_buf ( buf)
799
795
} ?;
800
796
801
-
797
+
802
798
self . io_mut ( ) . ok_or ( IoError :: new ( IoErrorKind :: Other , "stream has been shutdown" ) ) ?
803
799
. ecx . set ( cx) ;
804
-
800
+
805
801
let result = match unsafe { ssl_write ( ( & mut * self ) . into ( ) , buf. as_ptr ( ) , buf. len ( ) ) . into_result ( ) } {
806
802
Err ( Error :: SslPeerCloseNotify ) => Poll :: Ready ( Ok ( 0 ) ) ,
807
803
Err ( Error :: SslWantWrite ) => Poll :: Pending ,
@@ -868,7 +864,7 @@ mod tests {
868
864
869
865
use crate :: ssl:: context:: { HandshakeContext , Context } ;
870
866
use crate :: tests:: TestTrait ;
871
-
867
+
872
868
#[ test]
873
869
fn handshakecontext_sync ( ) {
874
870
assert ! ( !TestTrait :: <dyn Sync , HandshakeContext >:: new( ) . impls_trait( ) , "HandshakeContext must be !Sync" ) ;
@@ -884,7 +880,7 @@ mod tests {
884
880
unimplemented ! ( )
885
881
}
886
882
}
887
-
883
+
888
884
#[ cfg( feature = "std" ) ]
889
885
impl Write for NonSendStream {
890
886
fn write ( & mut self , _: & [ u8 ] ) -> IoResult < usize > {
@@ -906,7 +902,7 @@ mod tests {
906
902
unimplemented ! ( )
907
903
}
908
904
}
909
-
905
+
910
906
#[ cfg( feature = "std" ) ]
911
907
impl Write for SendStream {
912
908
fn write ( & mut self , _: & [ u8 ] ) -> IoResult < usize > {
0 commit comments