11use super :: { current, park, Builder , JoinInner , Result , Thread } ;
2+ use crate :: cell:: UnsafeCell ;
23use crate :: fmt;
34use crate :: io;
4- use crate :: marker:: PhantomData ;
5+ use crate :: marker:: { PhantomData , PhantomPinned } ;
56use crate :: panic:: { catch_unwind, resume_unwind, AssertUnwindSafe } ;
6- use crate :: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
7+ use crate :: pin:: Pin ;
8+ use crate :: ptr:: NonNull ;
9+ use crate :: sync:: atomic:: { fence, AtomicBool , AtomicUsize , Ordering } ;
710use crate :: sync:: Arc ;
11+ use core:: intrinsics:: { atomic_store_rel, atomic_xsub_rel} ;
812
913/// A scope to spawn scoped threads in.
1014///
1115/// See [`scope`] for details.
1216#[ stable( feature = "scoped_threads" , since = "1.63.0" ) ]
1317pub struct Scope < ' scope , ' env : ' scope > {
14- data : Arc < ScopeData > ,
18+ data : Pin < & ' scope ScopeData > ,
1519 /// Invariance over 'scope, to make sure 'scope cannot shrink,
1620 /// which is necessary for soundness.
1721 ///
@@ -35,28 +39,133 @@ pub struct Scope<'scope, 'env: 'scope> {
3539#[ stable( feature = "scoped_threads" , since = "1.63.0" ) ]
3640pub struct ScopedJoinHandle < ' scope , T > ( JoinInner < ' scope , T > ) ;
3741
/// Flag bit in `ScopeData::sync_state`, set while the scope's owner thread is
/// parked waiting for the running threads to finish.
const WAITING_BIT: usize = 1;

/// Increment applied to `sync_state` per running thread; the running-thread
/// count occupies the bits above `WAITING_BIT`.
const ONE_RUNNING: usize = 2;

/// Artificial limit on the maximum number of concurrently running threads in scope.
/// This is used to preemptively avoid hitting an overflow condition in the running thread count.
const MAX_RUNNING: usize = usize::MAX / 2;
48+
49+ #[ derive( Default ) ]
3850pub ( super ) struct ScopeData {
39- num_running_threads : AtomicUsize ,
40- a_thread_panicked : AtomicBool ,
41- main_thread : Thread ,
51+ sync_state : AtomicUsize ,
52+ thread_panicked : AtomicBool ,
53+ scope_thread : UnsafeCell < Option < Thread > > ,
54+ _pinned : PhantomPinned ,
4255}
4356
57+ unsafe impl Send for ScopeData { } // SAFETY: ScopeData needs to be sent to the spawned threads in the scope.
58+ unsafe impl Sync for ScopeData { } // SAFETY: ScopeData is shared between the spawned threads and the scope thread.
59+
4460impl ScopeData {
61+ /// Issues an Acquire fence which synchronizes with the `sync_state` Release sequence.
62+ fn fence_acquire_sync_state ( & self ) {
63+ // ThreadSanitizier doesn't properly support fences
64+ // so use an atomic load instead to avoid false positive data-race reports.
65+ if cfg ! ( sanitize = "thread" ) {
66+ self . sync_state . load ( Ordering :: Acquire ) ;
67+ } else {
68+ fence ( Ordering :: Acquire ) ;
69+ }
70+ }
71+
4572 pub ( super ) fn increment_num_running_threads ( & self ) {
46- // We check for 'overflow' with usize::MAX / 2, to make sure there's no
47- // chance it overflows to 0, which would result in unsoundness.
48- if self . num_running_threads . fetch_add ( 1 , Ordering :: Relaxed ) > usize:: MAX / 2 {
49- // This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles.
50- self . decrement_num_running_threads ( false ) ;
51- panic ! ( "too many running threads in thread scope" ) ;
73+ // No need for any memory barriers as this is just incrementing the running count
74+ // with the assumption that the ScopeData remains valid before and after this call.
75+ let state = self . sync_state . fetch_add ( ONE_RUNNING , Ordering :: Relaxed ) ;
76+
77+ // Make sure we're not spawning too many threads on the scope.
78+ // The `MAX_RUNNING` is intentionally lower than `usize::MAX` to detect overflow
79+ // conditions on the running count earlier, even in the presence of multiple threads.
80+ let running_threads = state / ONE_RUNNING ;
81+ assert ! ( running_threads <= MAX_RUNNING , "too many running threads in thread scope" ) ;
82+ }
83+
84+ /// Decrement the number of running threads with the assumption that one was running before.
85+ /// Once the number of running threads becomes zero, it wakes up the scope thread if it's waiting.
86+ /// The running thread count hitting zero "happens before" the scope thread returns from waiting.
87+ ///
88+ /// SAFETY:
89+ /// Caller must ensure that there was a matching call to increment_num_running_threadS() prior.
90+ pub ( super ) unsafe fn decrement_num_running_threads ( data : NonNull < Self > , panicked : bool ) {
91+ unsafe {
92+ if panicked {
93+ data. as_ref ( ) . thread_panicked . store ( true , Ordering :: Relaxed ) ;
94+ }
95+
96+ // Decrement the running count with a Release barrier.
97+ // This ensures that all data accesses and side effects before the decrement
98+ // "happen before" the scope thread observes the running count to be zero.
99+ let state_ptr = data. as_ref ( ) . sync_state . as_mut_ptr ( ) ;
100+ let state = atomic_xsub_rel ( state_ptr, ONE_RUNNING ) ;
101+
102+ let running_threads = state / ONE_RUNNING ;
103+ assert_ne ! (
104+ running_threads, 0 ,
105+ "decrement_num_running_threads called when not incremented"
106+ ) ;
107+
108+ // Wake up the scope thread if it's waiting and if we're the last running thread.
109+ if state == ( ONE_RUNNING | WAITING_BIT ) {
110+ // Acquire barrier ensures that both the scope_thread store and WAITING_BIT set,
111+ // along with the data accesses and decrements from previous threads,
112+ // "happen before" we start to wake up the scope thread.
113+ data. as_ref ( ) . fence_acquire_sync_state ( ) ;
114+
115+ let scope_thread = {
116+ let thread_ref = & mut * data. as_ref ( ) . scope_thread . get ( ) ;
117+ thread_ref. take ( ) . expect ( "ScopeData has no thread even when WAITING_BIT is set" )
118+ } ;
119+
120+ // Wake up the scope thread by removing the WAITING_BIT and unparking the thread.
121+ // Release barrier ensures the consume of `scope_thread` "happens before" the
122+ // waiting scope thread observes 0 and returns to invalidate our data pointer.
123+ atomic_store_rel ( state_ptr, 0 ) ;
124+ scope_thread. unpark ( ) ;
125+ }
52126 }
53127 }
54- pub ( super ) fn decrement_num_running_threads ( & self , panic : bool ) {
55- if panic {
56- self . a_thread_panicked . store ( true , Ordering :: Relaxed ) ;
128+
129+ /// Blocks the callers thread until all running threads have called decrement_num_running_threads().
130+ ///
131+ /// SAFETY:
132+ /// Caller must ensure that they're the sole scope_thread calling this function.
133+ /// There should also be no future calls to `increment_num_running_threads()` at this point.
134+ unsafe fn wait_for_running_threads ( & self ) {
135+ // Fast check to see if no threads are running.
136+ // Acquire barrier ensures the running thread count updates
137+ // and previous side effects on those threads "happen before" we observe 0 and return.
138+ if self . sync_state . load ( Ordering :: Acquire ) == 0 {
139+ return ;
57140 }
58- if self . num_running_threads . fetch_sub ( 1 , Ordering :: Release ) == 1 {
59- self . main_thread . unpark ( ) ;
141+
142+ // Register our Thread object to be unparked.
143+ unsafe {
144+ let thread_ref = & mut * self . scope_thread . get ( ) ;
145+ let old_scope_thread = thread_ref. replace ( current ( ) ) ;
146+ assert ! ( old_scope_thread. is_none( ) , "multiple threads waiting on same ScopeData" ) ;
147+ }
148+
149+ // Set the WAITING_BIT on the state to indicate there's a waiter.
150+ // Uses `fetch_add` over `fetch_or` as the former compiles to accelerated instructions on modern CPUs.
151+ // Release barrier ensures Thread registration above "happens before" WAITING_BIT is observed by last running thread.
152+ let state = self . sync_state . fetch_add ( WAITING_BIT , Ordering :: Release ) ;
153+ assert_eq ! ( state & WAITING_BIT , 0 , "multiple threads waiting on same ScopeData" ) ;
154+
155+ // Don't wait if all running threads completed while we were trying to set the WAITING_BIT.
156+ // Acquire barrier ensures all running thread count updates and related side effects "happen before" we return.
157+ if state / ONE_RUNNING == 0 {
158+ self . fence_acquire_sync_state ( ) ;
159+ return ;
160+ }
161+
162+ // Block the thread until the last running thread sees the WAITING_BIT and resets the state to zero.
163+ // Acquire barrier ensures all running thread count updates and related side effects "happen before" we return.
164+ loop {
165+ park ( ) ;
166+ if self . sync_state . load ( Ordering :: Acquire ) == 0 {
167+ return ;
168+ }
60169 }
61170 }
62171}
@@ -130,30 +239,26 @@ pub fn scope<'env, F, T>(f: F) -> T
130239where
131240 F : for < ' scope > FnOnce ( & ' scope Scope < ' scope , ' env > ) -> T ,
132241{
133- // We put the `ScopeData` into an `Arc` so that other threads can finish their
134- // `decrement_num_running_threads` even after this function returns.
135- let scope = Scope {
136- data : Arc :: new ( ScopeData {
137- num_running_threads : AtomicUsize :: new ( 0 ) ,
138- main_thread : current ( ) ,
139- a_thread_panicked : AtomicBool :: new ( false ) ,
140- } ) ,
141- env : PhantomData ,
142- scope : PhantomData ,
143- } ;
242+ // We can store the ScopeData on the stack as we're careful about accessing it intrusively.
243+ let data = ScopeData :: default ( ) ;
244+
// Store the ScopeData as Pinned to document in the type system
// that it must remain valid until it is dropped at the end of this function.
247+ // SAFETY: the ScopeData is stored on the stack.
248+ let scope =
249+ Scope { data : unsafe { Pin :: new_unchecked ( & data) } , env : PhantomData , scope : PhantomData } ;
144250
145251 // Run `f`, but catch panics so we can make sure to wait for all the threads to join.
146252 let result = catch_unwind ( AssertUnwindSafe ( || f ( & scope) ) ) ;
147253
148254 // Wait until all the threads are finished.
149- while scope. data . num_running_threads . load ( Ordering :: Acquire ) != 0 {
150- park ( ) ;
151- }
255+ // SAFETY: this is the only thread that calls ScopeData::wait_for_running_threads().
256+ unsafe { scope. data . wait_for_running_threads ( ) } ;
152257
153258 // Throw any panic from `f`, or the return value of `f` if no thread panicked.
154259 match result {
155260 Err ( e) => resume_unwind ( e) ,
156- Ok ( _) if scope. data . a_thread_panicked . load ( Ordering :: Relaxed ) => {
261+ Ok ( _) if scope. data . thread_panicked . load ( Ordering :: Relaxed ) => {
157262 panic ! ( "a scoped thread panicked" )
158263 }
159264 Ok ( result) => result,
@@ -252,7 +357,7 @@ impl Builder {
252357 F : FnOnce ( ) -> T + Send + ' scope ,
253358 T : Send + ' scope ,
254359 {
255- Ok ( ScopedJoinHandle ( unsafe { self . spawn_unchecked_ ( f, Some ( scope. data . clone ( ) ) ) } ?) )
360+ Ok ( ScopedJoinHandle ( unsafe { self . spawn_unchecked_ ( f, Some ( scope. data ) ) } ?) )
256361 }
257362}
258363
@@ -327,10 +432,14 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
327432#[ stable( feature = "scoped_threads" , since = "1.63.0" ) ]
328433impl fmt:: Debug for Scope < ' _ , ' _ > {
329434 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
435+ let state = self . data . sync_state . load ( Ordering :: Relaxed ) ;
436+ let num_running_threads = state / ONE_RUNNING ;
437+ let main_thread_waiting = state & WAITING_BIT != 0 ;
438+
330439 f. debug_struct ( "Scope" )
331- . field ( "num_running_threads" , & self . data . num_running_threads . load ( Ordering :: Relaxed ) )
332- . field ( "a_thread_panicked " , & self . data . a_thread_panicked . load ( Ordering :: Relaxed ) )
333- . field ( "main_thread " , & self . data . main_thread )
440+ . field ( "num_running_threads" , & num_running_threads)
441+ . field ( "thread_panicked " , & self . data . thread_panicked . load ( Ordering :: Relaxed ) )
442+ . field ( "main_thread_waiting " , & main_thread_waiting )
334443 . finish_non_exhaustive ( )
335444 }
336445}
0 commit comments