1
1
use super :: { current, park, Builder , JoinInner , Result , Thread } ;
2
+ use crate :: cell:: UnsafeCell ;
2
3
use crate :: fmt;
3
4
use crate :: io;
4
- use crate :: marker:: PhantomData ;
5
+ use crate :: marker:: { PhantomData , PhantomPinned } ;
5
6
use crate :: panic:: { catch_unwind, resume_unwind, AssertUnwindSafe } ;
6
- use crate :: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
7
+ use crate :: pin:: Pin ;
8
+ use crate :: ptr:: NonNull ;
9
+ use crate :: sync:: atomic:: { fence, AtomicBool , AtomicUsize , Ordering } ;
7
10
use crate :: sync:: Arc ;
11
+ use core:: intrinsics:: { atomic_store_rel, atomic_xsub_rel} ;
8
12
9
13
/// A scope to spawn scoped threads in.
10
14
///
11
15
/// See [`scope`] for details.
12
16
#[ stable( feature = "scoped_threads" , since = "1.63.0" ) ]
13
17
pub struct Scope < ' scope , ' env : ' scope > {
14
- data : Arc < ScopeData > ,
18
+ data : Pin < & ' scope ScopeData > ,
15
19
/// Invariance over 'scope, to make sure 'scope cannot shrink,
16
20
/// which is necessary for soundness.
17
21
///
@@ -35,28 +39,133 @@ pub struct Scope<'scope, 'env: 'scope> {
35
39
#[ stable( feature = "scoped_threads" , since = "1.63.0" ) ]
36
40
pub struct ScopedJoinHandle < ' scope , T > ( JoinInner < ' scope , T > ) ;
37
41
42
/// Bit 0 of `ScopeData::sync_state`: set once the scope thread has registered its
/// `Thread` handle in `scope_thread` and is (about to be) parked waiting for the
/// running threads to finish.
const WAITING_BIT: usize = 1;
/// The amount added to `sync_state` per running thread. The running-thread count
/// therefore lives in the bits above `WAITING_BIT` and is `state / ONE_RUNNING`.
const ONE_RUNNING: usize = 2;

/// Artificial limit on the maximum number of concurrently running threads in scope.
/// This is used to preemptively avoid hitting an overflow condition in the running thread count.
const MAX_RUNNING: usize = usize::MAX / 2;
48
+
49
#[derive(Default)]
pub(super) struct ScopeData {
    // Packed synchronization state:
    // - bit 0 (`WAITING_BIT`): the scope thread has registered itself and is waiting;
    // - upper bits (`state / ONE_RUNNING`): number of currently running threads.
    sync_state: AtomicUsize,
    // Set (Relaxed) by any spawned thread that panicked; read by the scope thread
    // after all threads have finished, to decide whether to propagate a panic.
    thread_panicked: AtomicBool,
    // The scope thread's handle, to be unparked by the last running thread.
    // Written by the scope thread *before* it publishes WAITING_BIT (Release),
    // taken by the single thread that observes the final decrement (after an
    // Acquire fence) — that handshake is what makes the UnsafeCell sound.
    scope_thread: UnsafeCell<Option<Thread>>,
    // Spawned threads hold raw pointers to this data, so it must never move
    // once shared; the scope stores it as `Pin<&ScopeData>`.
    _pinned: PhantomPinned,
}
43
56
57
// SAFETY: ScopeData is shared (by pinned reference / raw pointer) between the scope
// thread and the spawned threads. All mutable state is either atomic (`sync_state`,
// `thread_panicked`) or serialized by the `sync_state` handshake: `scope_thread` is
// written only before WAITING_BIT is published with Release, and taken only by the
// one thread that observes the final decrement after an Acquire fence.
// NOTE(review): these impls are what assert thread-safety of
// `UnsafeCell<Option<Thread>>`; confirm the handshake argument against the impl below.
unsafe impl Send for ScopeData {}
unsafe impl Sync for ScopeData {}
59
+
44
60
impl ScopeData {
61
+ /// Issues an Acquire fence which synchronizes with the `sync_state` Release sequence.
62
+ fn fence_acquire_sync_state ( & self ) {
63
+ // ThreadSanitizier doesn't properly support fences
64
+ // so use an atomic load instead to avoid false positive data-race reports.
65
+ if cfg ! ( sanitize = "thread" ) {
66
+ self . sync_state . load ( Ordering :: Acquire ) ;
67
+ } else {
68
+ fence ( Ordering :: Acquire ) ;
69
+ }
70
+ }
71
+
45
72
pub ( super ) fn increment_num_running_threads ( & self ) {
46
- // We check for 'overflow' with usize::MAX / 2, to make sure there's no
47
- // chance it overflows to 0, which would result in unsoundness.
48
- if self . num_running_threads . fetch_add ( 1 , Ordering :: Relaxed ) > usize:: MAX / 2 {
49
- // This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles.
50
- self . decrement_num_running_threads ( false ) ;
51
- panic ! ( "too many running threads in thread scope" ) ;
73
+ // No need for any memory barriers as this is just incrementing the running count
74
+ // with the assumption that the ScopeData remains valid before and after this call.
75
+ let state = self . sync_state . fetch_add ( ONE_RUNNING , Ordering :: Relaxed ) ;
76
+
77
+ // Make sure we're not spawning too many threads on the scope.
78
+ // The `MAX_RUNNING` is intentionally lower than `usize::MAX` to detect overflow
79
+ // conditions on the running count earlier, even in the presence of multiple threads.
80
+ let running_threads = state / ONE_RUNNING ;
81
+ assert ! ( running_threads <= MAX_RUNNING , "too many running threads in thread scope" ) ;
82
+ }
83
+
84
+ /// Decrement the number of running threads with the assumption that one was running before.
85
+ /// Once the number of running threads becomes zero, it wakes up the scope thread if it's waiting.
86
+ /// The running thread count hitting zero "happens before" the scope thread returns from waiting.
87
+ ///
88
+ /// SAFETY:
89
+ /// Caller must ensure that there was a matching call to increment_num_running_threadS() prior.
90
+ pub ( super ) unsafe fn decrement_num_running_threads ( data : NonNull < Self > , panicked : bool ) {
91
+ unsafe {
92
+ if panicked {
93
+ data. as_ref ( ) . thread_panicked . store ( true , Ordering :: Relaxed ) ;
94
+ }
95
+
96
+ // Decrement the running count with a Release barrier.
97
+ // This ensures that all data accesses and side effects before the decrement
98
+ // "happen before" the scope thread observes the running count to be zero.
99
+ let state_ptr = data. as_ref ( ) . sync_state . as_mut_ptr ( ) ;
100
+ let state = atomic_xsub_rel ( state_ptr, ONE_RUNNING ) ;
101
+
102
+ let running_threads = state / ONE_RUNNING ;
103
+ assert_ne ! (
104
+ running_threads, 0 ,
105
+ "decrement_num_running_threads called when not incremented"
106
+ ) ;
107
+
108
+ // Wake up the scope thread if it's waiting and if we're the last running thread.
109
+ if state == ( ONE_RUNNING | WAITING_BIT ) {
110
+ // Acquire barrier ensures that both the scope_thread store and WAITING_BIT set,
111
+ // along with the data accesses and decrements from previous threads,
112
+ // "happen before" we start to wake up the scope thread.
113
+ data. as_ref ( ) . fence_acquire_sync_state ( ) ;
114
+
115
+ let scope_thread = {
116
+ let thread_ref = & mut * data. as_ref ( ) . scope_thread . get ( ) ;
117
+ thread_ref. take ( ) . expect ( "ScopeData has no thread even when WAITING_BIT is set" )
118
+ } ;
119
+
120
+ // Wake up the scope thread by removing the WAITING_BIT and unparking the thread.
121
+ // Release barrier ensures the consume of `scope_thread` "happens before" the
122
+ // waiting scope thread observes 0 and returns to invalidate our data pointer.
123
+ atomic_store_rel ( state_ptr, 0 ) ;
124
+ scope_thread. unpark ( ) ;
125
+ }
52
126
}
53
127
}
54
- pub ( super ) fn decrement_num_running_threads ( & self , panic : bool ) {
55
- if panic {
56
- self . a_thread_panicked . store ( true , Ordering :: Relaxed ) ;
128
+
129
+ /// Blocks the callers thread until all running threads have called decrement_num_running_threads().
130
+ ///
131
+ /// SAFETY:
132
+ /// Caller must ensure that they're the sole scope_thread calling this function.
133
+ /// There should also be no future calls to `increment_num_running_threads()` at this point.
134
+ unsafe fn wait_for_running_threads ( & self ) {
135
+ // Fast check to see if no threads are running.
136
+ // Acquire barrier ensures the running thread count updates
137
+ // and previous side effects on those threads "happen before" we observe 0 and return.
138
+ if self . sync_state . load ( Ordering :: Acquire ) == 0 {
139
+ return ;
57
140
}
58
- if self . num_running_threads . fetch_sub ( 1 , Ordering :: Release ) == 1 {
59
- self . main_thread . unpark ( ) ;
141
+
142
+ // Register our Thread object to be unparked.
143
+ unsafe {
144
+ let thread_ref = & mut * self . scope_thread . get ( ) ;
145
+ let old_scope_thread = thread_ref. replace ( current ( ) ) ;
146
+ assert ! ( old_scope_thread. is_none( ) , "multiple threads waiting on same ScopeData" ) ;
147
+ }
148
+
149
+ // Set the WAITING_BIT on the state to indicate there's a waiter.
150
+ // Uses `fetch_add` over `fetch_or` as the former compiles to accelerated instructions on modern CPUs.
151
+ // Release barrier ensures Thread registration above "happens before" WAITING_BIT is observed by last running thread.
152
+ let state = self . sync_state . fetch_add ( WAITING_BIT , Ordering :: Release ) ;
153
+ assert_eq ! ( state & WAITING_BIT , 0 , "multiple threads waiting on same ScopeData" ) ;
154
+
155
+ // Don't wait if all running threads completed while we were trying to set the WAITING_BIT.
156
+ // Acquire barrier ensures all running thread count updates and related side effects "happen before" we return.
157
+ if state / ONE_RUNNING == 0 {
158
+ self . fence_acquire_sync_state ( ) ;
159
+ return ;
160
+ }
161
+
162
+ // Block the thread until the last running thread sees the WAITING_BIT and resets the state to zero.
163
+ // Acquire barrier ensures all running thread count updates and related side effects "happen before" we return.
164
+ loop {
165
+ park ( ) ;
166
+ if self . sync_state . load ( Ordering :: Acquire ) == 0 {
167
+ return ;
168
+ }
60
169
}
61
170
}
62
171
}
@@ -130,30 +239,26 @@ pub fn scope<'env, F, T>(f: F) -> T
130
239
where
131
240
F : for < ' scope > FnOnce ( & ' scope Scope < ' scope , ' env > ) -> T ,
132
241
{
133
- // We put the `ScopeData` into an `Arc` so that other threads can finish their
134
- // `decrement_num_running_threads` even after this function returns.
135
- let scope = Scope {
136
- data : Arc :: new ( ScopeData {
137
- num_running_threads : AtomicUsize :: new ( 0 ) ,
138
- main_thread : current ( ) ,
139
- a_thread_panicked : AtomicBool :: new ( false ) ,
140
- } ) ,
141
- env : PhantomData ,
142
- scope : PhantomData ,
143
- } ;
242
+ // We can store the ScopeData on the stack as we're careful about accessing it intrusively.
243
+ let data = ScopeData :: default ( ) ;
244
+
245
+ // Make sure the store the ScopeData as Pinned to document in the type system
246
+ // that it must remain valid until it is dropped at the end of this function.
247
+ // SAFETY: the ScopeData is stored on the stack.
248
+ let scope =
249
+ Scope { data : unsafe { Pin :: new_unchecked ( & data) } , env : PhantomData , scope : PhantomData } ;
144
250
145
251
// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
146
252
let result = catch_unwind ( AssertUnwindSafe ( || f ( & scope) ) ) ;
147
253
148
254
// Wait until all the threads are finished.
149
- while scope. data . num_running_threads . load ( Ordering :: Acquire ) != 0 {
150
- park ( ) ;
151
- }
255
+ // SAFETY: this is the only thread that calls ScopeData::wait_for_running_threads().
256
+ unsafe { scope. data . wait_for_running_threads ( ) } ;
152
257
153
258
// Throw any panic from `f`, or the return value of `f` if no thread panicked.
154
259
match result {
155
260
Err ( e) => resume_unwind ( e) ,
156
- Ok ( _) if scope. data . a_thread_panicked . load ( Ordering :: Relaxed ) => {
261
+ Ok ( _) if scope. data . thread_panicked . load ( Ordering :: Relaxed ) => {
157
262
panic ! ( "a scoped thread panicked" )
158
263
}
159
264
Ok ( result) => result,
@@ -252,7 +357,7 @@ impl Builder {
252
357
F : FnOnce ( ) -> T + Send + ' scope ,
253
358
T : Send + ' scope ,
254
359
{
255
- Ok ( ScopedJoinHandle ( unsafe { self . spawn_unchecked_ ( f, Some ( scope. data . clone ( ) ) ) } ?) )
360
+ Ok ( ScopedJoinHandle ( unsafe { self . spawn_unchecked_ ( f, Some ( scope. data ) ) } ?) )
256
361
}
257
362
}
258
363
@@ -327,10 +432,14 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
327
432
#[stable(feature = "scoped_threads", since = "1.63.0")]
impl fmt::Debug for Scope<'_, '_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Decode the packed sync_state once (Relaxed: this is a best-effort
        // diagnostic snapshot, not a synchronization point): the upper bits hold
        // the running-thread count, bit 0 records whether the scope (main) thread
        // is parked waiting.
        let state = self.data.sync_state.load(Ordering::Relaxed);
        let num_running_threads = state / ONE_RUNNING;
        let main_thread_waiting = state & WAITING_BIT != 0;

        f.debug_struct("Scope")
            .field("num_running_threads", &num_running_threads)
            .field("thread_panicked", &self.data.thread_panicked.load(Ordering::Relaxed))
            .field("main_thread_waiting", &main_thread_waiting)
            .finish_non_exhaustive()
    }
}
0 commit comments