@@ -9,7 +9,7 @@ use std::marker::PhantomData;
9
9
use std:: sync:: Arc ;
10
10
11
11
thread_local ! {
12
- static CURRENT_CONTEXT : RefCell <Context > = RefCell :: new( Context :: default ( ) ) ;
12
+ static CURRENT_CONTEXT : RefCell <ContextStack > = RefCell :: new( ContextStack :: default ( ) ) ;
13
13
}
14
14
15
15
/// An execution-scoped collection of values.
@@ -122,7 +122,7 @@ impl Context {
122
122
/// Note: This function will panic if you attempt to attach another context
123
123
/// while the current one is still borrowed.
124
124
pub fn map_current < T > ( f : impl FnOnce ( & Context ) -> T ) -> T {
125
- CURRENT_CONTEXT . with ( |cx| f ( & cx. borrow ( ) ) )
125
+ CURRENT_CONTEXT . with ( |cx| cx. borrow ( ) . map_current_cx ( f ) )
126
126
}
127
127
128
128
/// Returns a clone of the current thread's context with the given value.
@@ -298,12 +298,10 @@ impl Context {
298
298
/// assert_eq!(Context::current().get::<ValueA>(), None);
299
299
/// ```
300
300
pub fn attach ( self ) -> ContextGuard {
301
- let previous_cx = CURRENT_CONTEXT
302
- . try_with ( |current| current. replace ( self ) )
303
- . ok ( ) ;
301
+ let cx_id = CURRENT_CONTEXT . with ( |cx| cx. borrow_mut ( ) . push ( self ) ) ;
304
302
305
303
ContextGuard {
306
- previous_cx ,
304
+ cx_pos : cx_id ,
307
305
_marker : PhantomData ,
308
306
}
309
307
}
@@ -344,17 +342,19 @@ impl fmt::Debug for Context {
344
342
}
345
343
346
344
/// A guard that resets the current context to the prior context when dropped.
347
- #[ allow ( missing_debug_implementations ) ]
345
+ #[ derive ( Debug ) ]
348
346
pub struct ContextGuard {
349
- previous_cx : Option < Context > ,
350
- // ensure this type is !Send as it relies on thread locals
347
+ // The position of the context in the stack. This is used to pop the context.
348
+ cx_pos : u16 ,
349
+ // Ensure this type is !Send as it relies on thread locals
351
350
_marker : PhantomData < * const ( ) > ,
352
351
}
353
352
354
353
impl Drop for ContextGuard {
355
354
fn drop ( & mut self ) {
356
- if let Some ( previous_cx) = self . previous_cx . take ( ) {
357
- let _ = CURRENT_CONTEXT . try_with ( |current| current. replace ( previous_cx) ) ;
355
+ let id = self . cx_pos ;
356
+ if id > ContextStack :: BASE_POS && id < ContextStack :: MAX_POS {
357
+ CURRENT_CONTEXT . with ( |context_stack| context_stack. borrow_mut ( ) . pop_id ( id) ) ;
358
358
}
359
359
}
360
360
}
@@ -381,10 +381,107 @@ impl Hasher for IdHasher {
381
381
}
382
382
}
383
383
384
+ /// A stack for keeping track of the [`Context`] instances that have been attached
385
+ /// to a thread.
386
+ ///
387
+ /// The stack allows for popping of contexts by position, which is used to do out
388
+ /// of order dropping of [`ContextGuard`] instances. Only when the top of the
389
+ /// stack is popped, the topmost [`Context`] is actually restored.
390
+ ///
391
+ /// The stack relies on the fact that it is thread local and that the
392
+ /// [`ContextGuard`] instances that are constructed using it can't be shared with
393
+ /// other threads.
394
+ struct ContextStack {
395
+ /// This is the current [`Context`] that is active on this thread, and the top
396
+ /// of the [`ContextStack`]. It is always present, and if the `stack` is empty
397
+ /// it's an empty [`Context`].
398
+ ///
399
+ /// Having this here allows for fast access to the current [`Context`].
400
+ current_cx : Context ,
401
+ /// A `stack` of the other contexts that have been attached to the thread.
402
+ stack : Vec < Option < Context > > ,
403
+ /// Ensure this type is !Send as it relies on thread locals
404
+ _marker : PhantomData < * const ( ) > ,
405
+ }
406
+
407
+ impl ContextStack {
408
+ const BASE_POS : u16 = 0 ;
409
+ const MAX_POS : u16 = u16:: MAX ;
410
+ const INITIAL_CAPACITY : usize = 8 ;
411
+
412
+ #[ inline( always) ]
413
+ fn push ( & mut self , cx : Context ) -> u16 {
414
+ // The next id is the length of the `stack`, plus one since we have the
415
+ // top of the [`ContextStack`] as the `current_cx`.
416
+ let next_id = self . stack . len ( ) + 1 ;
417
+ if next_id < ContextStack :: MAX_POS . into ( ) {
418
+ let current_cx = std:: mem:: replace ( & mut self . current_cx , cx) ;
419
+ self . stack . push ( Some ( current_cx) ) ;
420
+ next_id as u16
421
+ } else {
422
+ // This is an overflow, log it and ignore it.
423
+ // TODO:ban add logging here
424
+ ContextStack :: MAX_POS
425
+ }
426
+ }
427
+
428
+ #[ inline( always) ]
429
+ fn pop_id ( & mut self , pos : u16 ) {
430
+ if pos == ContextStack :: BASE_POS || pos == ContextStack :: MAX_POS {
431
+ // The empty context is always at the bottom of the [`ContextStack`]
432
+ // and cannot be popped, and the overflow position is invalid, so do
433
+ // nothing.
434
+ return ;
435
+ }
436
+ let len: u16 = self . stack . len ( ) as u16 ;
437
+ // Are we at the top of the [`ContextStack`]?
438
+ if pos == len {
439
+ // Shrink the stack if possible to clear out any out of order pops.
440
+ while let Some ( None ) = self . stack . last ( ) {
441
+ _ = self . stack . pop ( ) ;
442
+ }
443
+ // Restore the previous context. This will always happen since the
444
+ // empty context is always at the bottom of the stack if the
445
+ // [`ContextStack`] is not empty.
446
+ if let Some ( Some ( next_cx) ) = self . stack . pop ( ) {
447
+ self . current_cx = next_cx;
448
+ }
449
+ } else {
450
+ // This is an out of order pop.
451
+ if pos >= len {
452
+ // This is an invalid id, ignore it.
453
+ return ;
454
+ }
455
+ // Clear out the entry at the given id.
456
+ _ = self . stack [ pos as usize ] . take ( ) ;
457
+ }
458
+ }
459
+
460
+ #[ inline( always) ]
461
+ fn map_current_cx < T > ( & self , f : impl FnOnce ( & Context ) -> T ) -> T {
462
+ f ( & self . current_cx )
463
+ }
464
+ }
465
+
466
+ impl Default for ContextStack {
467
+ fn default ( ) -> Self {
468
+ ContextStack {
469
+ current_cx : Context :: default ( ) ,
470
+ stack : Vec :: with_capacity ( ContextStack :: INITIAL_CAPACITY ) ,
471
+ _marker : PhantomData ,
472
+ }
473
+ }
474
+ }
475
+
384
476
#[ cfg( test) ]
385
477
mod tests {
386
478
use super :: * ;
387
479
480
+ #[ derive( Debug , PartialEq ) ]
481
+ struct ValueA ( & ' static str ) ;
482
+ #[ derive( Debug , PartialEq ) ]
483
+ struct ValueB ( u64 ) ;
484
+
388
485
#[ test]
389
486
fn context_immutable ( ) {
390
487
#[ derive( Debug , PartialEq ) ]
@@ -424,10 +521,6 @@ mod tests {
424
521
425
522
#[ test]
426
523
fn nested_contexts ( ) {
427
- #[ derive( Debug , PartialEq ) ]
428
- struct ValueA ( & ' static str ) ;
429
- #[ derive( Debug , PartialEq ) ]
430
- struct ValueB ( u64 ) ;
431
524
let _outer_guard = Context :: new ( ) . with_value ( ValueA ( "a" ) ) . attach ( ) ;
432
525
433
526
// Only value `a` is set
@@ -462,13 +555,7 @@ mod tests {
462
555
}
463
556
464
557
#[ test]
465
- #[ ignore = "overlapping contexts are not supported yet" ]
466
558
fn overlapping_contexts ( ) {
467
- #[ derive( Debug , PartialEq ) ]
468
- struct ValueA ( & ' static str ) ;
469
- #[ derive( Debug , PartialEq ) ]
470
- struct ValueB ( u64 ) ;
471
-
472
559
let outer_guard = Context :: new ( ) . with_value ( ValueA ( "a" ) ) . attach ( ) ;
473
560
474
561
// Only value `a` is set
@@ -502,4 +589,60 @@ mod tests {
502
589
assert_eq ! ( current. get:: <ValueA >( ) , None ) ;
503
590
assert_eq ! ( current. get:: <ValueB >( ) , None ) ;
504
591
}
592
+
593
+ #[ test]
594
+ fn too_many_contexts ( ) {
595
+ let mut guards: Vec < ContextGuard > = Vec :: with_capacity ( ContextStack :: MAX_POS as usize ) ;
596
+ let stack_max_pos = ContextStack :: MAX_POS as u64 ;
597
+ // Fill the stack up until the last position
598
+ for i in 1 ..stack_max_pos {
599
+ let cx_guard = Context :: current ( ) . with_value ( ValueB ( i) ) . attach ( ) ;
600
+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( i) ) ) ;
601
+ assert_eq ! ( cx_guard. cx_pos, i as u16 ) ;
602
+ guards. push ( cx_guard) ;
603
+ }
604
+ // Let's overflow the stack a couple of times
605
+ for _ in 0 ..16 {
606
+ let cx_guard = Context :: current ( ) . with_value ( ValueA ( "overflow" ) ) . attach ( ) ;
607
+ assert_eq ! ( cx_guard. cx_pos, ContextStack :: MAX_POS ) ;
608
+ assert_eq ! ( Context :: current( ) . get:: <ValueA >( ) , None ) ;
609
+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 1 ) ) ) ;
610
+ guards. push ( cx_guard) ;
611
+ }
612
+ // Drop the overflow contexts
613
+ for _ in 0 ..16 {
614
+ guards. pop ( ) ;
615
+ assert_eq ! ( Context :: current( ) . get:: <ValueA >( ) , None ) ;
616
+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 1 ) ) ) ;
617
+ }
618
+ // Drop one more so we can add a new one
619
+ guards. pop ( ) ;
620
+ assert_eq ! ( Context :: current( ) . get:: <ValueA >( ) , None ) ;
621
+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 2 ) ) ) ;
622
+ // Push a new context and see that it works
623
+ let cx_guard = Context :: current ( ) . with_value ( ValueA ( "last" ) ) . attach ( ) ;
624
+ assert_eq ! ( cx_guard. cx_pos, ContextStack :: MAX_POS - 1 ) ;
625
+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueA ( "last" ) ) ) ;
626
+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 2 ) ) ) ;
627
+ guards. push ( cx_guard) ;
628
+ // Let's overflow the stack a couple of times again
629
+ for _ in 0 ..16 {
630
+ let cx_guard = Context :: current ( ) . with_value ( ValueA ( "overflow" ) ) . attach ( ) ;
631
+ assert_eq ! ( cx_guard. cx_pos, ContextStack :: MAX_POS ) ;
632
+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueA ( "last" ) ) ) ;
633
+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 2 ) ) ) ;
634
+ guards. push ( cx_guard) ;
635
+ }
636
+ }
637
+
638
+ #[ test]
639
+ fn context_stack_pop_id ( ) {
640
+ // This is to get full line coverage of the `pop_id` function.
641
+ // In real life the `Drop`` implementation of `ContextGuard` ensures that
642
+ // the ids are valid and inside the bounds.
643
+ let mut stack = ContextStack :: default ( ) ;
644
+ stack. pop_id ( ContextStack :: BASE_POS ) ;
645
+ stack. pop_id ( ContextStack :: MAX_POS ) ;
646
+ stack. pop_id ( 4711 ) ;
647
+ }
505
648
}
0 commit comments