Skip to content

Commit 7ecb275

Browse files
committed
std::thread::scope() improvements
- avoids calling Thread::unpark() opportunistically. - avoids calling thread::current() opportunistically. - properly pins ScopeData to indicate its on-stack nature. - outlines fast and slow paths for inc/dec/wait. - specializes on ThreadSanitizer for fences. - uses atomic intrinsics to avoid dangling shared references.
1 parent ca1e68b commit 7ecb275

File tree

3 files changed

+165
-44
lines changed

3 files changed

+165
-44
lines changed

library/std/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@
287287
#![feature(prelude_2024)]
288288
#![feature(ptr_as_uninit)]
289289
#![feature(raw_os_nonzero)]
290+
#![feature(cfg_sanitize)]
290291
#![feature(slice_internals)]
291292
#![feature(slice_ptr_get)]
292293
#![feature(std_internals)]

library/std/src/thread/mod.rs

+19-8
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ use crate::num::NonZeroUsize;
166166
use crate::panic;
167167
use crate::panicking;
168168
use crate::pin::Pin;
169-
use crate::ptr::addr_of_mut;
169+
use crate::ptr::{addr_of_mut, NonNull};
170170
use crate::str;
171171
use crate::sync::Arc;
172172
use crate::sys::thread as imp;
@@ -463,7 +463,7 @@ impl Builder {
463463
unsafe fn spawn_unchecked_<'a, 'scope, F, T>(
464464
self,
465465
f: F,
466-
scope_data: Option<Arc<scoped::ScopeData>>,
466+
scope_data: Option<Pin<&'scope scoped::ScopeData>>,
467467
) -> io::Result<JoinInner<'scope, T>>
468468
where
469469
F: FnOnce() -> T,
@@ -481,7 +481,7 @@ impl Builder {
481481
let their_thread = my_thread.clone();
482482

483483
let my_packet: Arc<Packet<'scope, T>> = Arc::new(Packet {
484-
scope: scope_data,
484+
scope: scope_data.map(|data| NonNull::from(data.get_ref())),
485485
result: UnsafeCell::new(None),
486486
_marker: PhantomData,
487487
});
@@ -511,8 +511,8 @@ impl Builder {
511511
unsafe { *their_packet.result.get() = Some(try_result) };
512512
};
513513

514-
if let Some(scope_data) = &my_packet.scope {
515-
scope_data.increment_num_running_threads();
514+
if let Some(scope_data) = my_packet.scope {
515+
unsafe { scope_data.as_ref().increment_num_running_threads() };
516516
}
517517

518518
Ok(JoinInner {
@@ -1302,7 +1302,7 @@ pub type Result<T> = crate::result::Result<T, Box<dyn Any + Send + 'static>>;
13021302
// An Arc to the packet is stored into a `JoinInner` which in turns is placed
13031303
// in `JoinHandle`.
13041304
struct Packet<'scope, T> {
1305-
scope: Option<Arc<scoped::ScopeData>>,
1305+
scope: Option<NonNull<scoped::ScopeData>>,
13061306
result: UnsafeCell<Option<Result<T>>>,
13071307
_marker: PhantomData<Option<&'scope scoped::ScopeData>>,
13081308
}
@@ -1335,12 +1335,23 @@ impl<'scope, T> Drop for Packet<'scope, T> {
13351335
rtabort!("thread result panicked on drop");
13361336
}
13371337
// Book-keeping so the scope knows when it's done.
1338-
if let Some(scope) = &self.scope {
1338+
if let Some(scope_data) = self.scope {
13391339
// Now that there will be no more user code running on this thread
13401340
// that can use 'scope, mark the thread as 'finished'.
13411341
// It's important we only do this after the `result` has been dropped,
13421342
// since dropping it might still use things it borrowed from 'scope.
1343-
scope.decrement_num_running_threads(unhandled_panic);
1343+
//
1344+
// A static method to decrement is used to keep `ScopeData` as a raw pointer.
1345+
// Using a reference risks the decrement function waking the `scope()` thread,
1346+
// invalidating our `ScopeData`, and leaving us with a dangling dereferenceable &ScopeData.
1347+
// This avoids issue #55005.
1348+
//
1349+
// SAFETY:
1350+
// Given the thread has been spawned,
1351+
// there was a matching call to `ScopeData::increment_num_running_threads()`.
1352+
unsafe {
1353+
scoped::ScopeData::decrement_num_running_threads(scope_data, unhandled_panic);
1354+
}
13441355
}
13451356
}
13461357
}

library/std/src/thread/scoped.rs

+145-36
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
use super::{current, park, Builder, JoinInner, Result, Thread};
2+
use crate::cell::UnsafeCell;
23
use crate::fmt;
34
use crate::io;
4-
use crate::marker::PhantomData;
5+
use crate::marker::{PhantomData, PhantomPinned};
56
use crate::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
6-
use crate::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7+
use crate::pin::Pin;
8+
use crate::ptr::NonNull;
9+
use crate::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};
710
use crate::sync::Arc;
11+
use core::intrinsics::{atomic_store_rel, atomic_xsub_rel};
812

913
/// A scope to spawn scoped threads in.
1014
///
1115
/// See [`scope`] for details.
1216
#[stable(feature = "scoped_threads", since = "1.63.0")]
1317
pub struct Scope<'scope, 'env: 'scope> {
14-
data: Arc<ScopeData>,
18+
data: Pin<&'scope ScopeData>,
1519
/// Invariance over 'scope, to make sure 'scope cannot shrink,
1620
/// which is necessary for soundness.
1721
///
@@ -35,28 +39,133 @@ pub struct Scope<'scope, 'env: 'scope> {
3539
#[stable(feature = "scoped_threads", since = "1.63.0")]
3640
pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>);
3741

42+
const WAITING_BIT: usize = 1;
43+
const ONE_RUNNING: usize = 2;
44+
45+
/// Artificial limit on the maximum number of concurrently running threads in scope.
46+
/// This is used to preemptively avoid hitting an overflow condition in the running thread count.
47+
const MAX_RUNNING: usize = usize::MAX / 2;
48+
49+
#[derive(Default)]
3850
pub(super) struct ScopeData {
39-
num_running_threads: AtomicUsize,
40-
a_thread_panicked: AtomicBool,
41-
main_thread: Thread,
51+
sync_state: AtomicUsize,
52+
thread_panicked: AtomicBool,
53+
scope_thread: UnsafeCell<Option<Thread>>,
54+
_pinned: PhantomPinned,
4255
}
4356

57+
unsafe impl Send for ScopeData {} // SAFETY: ScopeData needs to be sent to the spawned threads in the scope.
58+
unsafe impl Sync for ScopeData {} // SAFETY: ScopeData is shared between the spawned threads and the scope thread.
59+
4460
impl ScopeData {
61+
/// Issues an Acquire fence which synchronizes with the `sync_state` Release sequence.
62+
fn fence_acquire_sync_state(&self) {
63+
// ThreadSanitizier doesn't properly support fences
64+
// so use an atomic load instead to avoid false positive data-race reports.
65+
if cfg!(sanitize = "thread") {
66+
self.sync_state.load(Ordering::Acquire);
67+
} else {
68+
fence(Ordering::Acquire);
69+
}
70+
}
71+
4572
pub(super) fn increment_num_running_threads(&self) {
46-
// We check for 'overflow' with usize::MAX / 2, to make sure there's no
47-
// chance it overflows to 0, which would result in unsoundness.
48-
if self.num_running_threads.fetch_add(1, Ordering::Relaxed) > usize::MAX / 2 {
49-
// This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles.
50-
self.decrement_num_running_threads(false);
51-
panic!("too many running threads in thread scope");
73+
// No need for any memory barriers as this is just incrementing the running count
74+
// with the assumption that the ScopeData remains valid before and after this call.
75+
let state = self.sync_state.fetch_add(ONE_RUNNING, Ordering::Relaxed);
76+
77+
// Make sure we're not spawning too many threads on the scope.
78+
// The `MAX_RUNNING` is intentionally lower than `usize::MAX` to detect overflow
79+
// conditions on the running count earlier, even in the presence of multiple threads.
80+
let running_threads = state / ONE_RUNNING;
81+
assert!(running_threads <= MAX_RUNNING, "too many running threads in thread scope");
82+
}
83+
84+
/// Decrement the number of running threads with the assumption that one was running before.
85+
/// Once the number of running threads becomes zero, it wakes up the scope thread if it's waiting.
86+
/// The running thread count hitting zero "happens before" the scope thread returns from waiting.
87+
///
88+
/// SAFETY:
89+
/// Caller must ensure that there was a matching call to increment_num_running_threads() prior.
90+
pub(super) unsafe fn decrement_num_running_threads(data: NonNull<Self>, panicked: bool) {
91+
unsafe {
92+
if panicked {
93+
data.as_ref().thread_panicked.store(true, Ordering::Relaxed);
94+
}
95+
96+
// Decrement the running count with a Release barrier.
97+
// This ensures that all data accesses and side effects before the decrement
98+
// "happen before" the scope thread observes the running count to be zero.
99+
let state_ptr = data.as_ref().sync_state.as_mut_ptr();
100+
let state = atomic_xsub_rel(state_ptr, ONE_RUNNING);
101+
102+
let running_threads = state / ONE_RUNNING;
103+
assert_ne!(
104+
running_threads, 0,
105+
"decrement_num_running_threads called when not incremented"
106+
);
107+
108+
// Wake up the scope thread if it's waiting and if we're the last running thread.
109+
if state == (ONE_RUNNING | WAITING_BIT) {
110+
// Acquire barrier ensures that both the scope_thread store and WAITING_BIT set,
111+
// along with the data accesses and decrements from previous threads,
112+
// "happen before" we start to wake up the scope thread.
113+
data.as_ref().fence_acquire_sync_state();
114+
115+
let scope_thread = {
116+
let thread_ref = &mut *data.as_ref().scope_thread.get();
117+
thread_ref.take().expect("ScopeData has no thread even when WAITING_BIT is set")
118+
};
119+
120+
// Wake up the scope thread by removing the WAITING_BIT and unparking the thread.
121+
// Release barrier ensures the consume of `scope_thread` "happens before" the
122+
// waiting scope thread observes 0 and returns to invalidate our data pointer.
123+
atomic_store_rel(state_ptr, 0);
124+
scope_thread.unpark();
125+
}
52126
}
53127
}
54-
pub(super) fn decrement_num_running_threads(&self, panic: bool) {
55-
if panic {
56-
self.a_thread_panicked.store(true, Ordering::Relaxed);
128+
129+
/// Blocks the caller's thread until all running threads have called decrement_num_running_threads().
130+
///
131+
/// SAFETY:
132+
/// Caller must ensure that they're the sole scope_thread calling this function.
133+
/// There should also be no future calls to `increment_num_running_threads()` at this point.
134+
unsafe fn wait_for_running_threads(&self) {
135+
// Fast check to see if no threads are running.
136+
// Acquire barrier ensures the running thread count updates
137+
// and previous side effects on those threads "happen before" we observe 0 and return.
138+
if self.sync_state.load(Ordering::Acquire) == 0 {
139+
return;
57140
}
58-
if self.num_running_threads.fetch_sub(1, Ordering::Release) == 1 {
59-
self.main_thread.unpark();
141+
142+
// Register our Thread object to be unparked.
143+
unsafe {
144+
let thread_ref = &mut *self.scope_thread.get();
145+
let old_scope_thread = thread_ref.replace(current());
146+
assert!(old_scope_thread.is_none(), "multiple threads waiting on same ScopeData");
147+
}
148+
149+
// Set the WAITING_BIT on the state to indicate there's a waiter.
150+
// Uses `fetch_add` over `fetch_or` as the former compiles to accelerated instructions on modern CPUs.
151+
// Release barrier ensures Thread registration above "happens before" WAITING_BIT is observed by last running thread.
152+
let state = self.sync_state.fetch_add(WAITING_BIT, Ordering::Release);
153+
assert_eq!(state & WAITING_BIT, 0, "multiple threads waiting on same ScopeData");
154+
155+
// Don't wait if all running threads completed while we were trying to set the WAITING_BIT.
156+
// Acquire barrier ensures all running thread count updates and related side effects "happen before" we return.
157+
if state / ONE_RUNNING == 0 {
158+
self.fence_acquire_sync_state();
159+
return;
160+
}
161+
162+
// Block the thread until the last running thread sees the WAITING_BIT and resets the state to zero.
163+
// Acquire barrier ensures all running thread count updates and related side effects "happen before" we return.
164+
loop {
165+
park();
166+
if self.sync_state.load(Ordering::Acquire) == 0 {
167+
return;
168+
}
60169
}
61170
}
62171
}
@@ -130,30 +239,26 @@ pub fn scope<'env, F, T>(f: F) -> T
130239
where
131240
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
132241
{
133-
// We put the `ScopeData` into an `Arc` so that other threads can finish their
134-
// `decrement_num_running_threads` even after this function returns.
135-
let scope = Scope {
136-
data: Arc::new(ScopeData {
137-
num_running_threads: AtomicUsize::new(0),
138-
main_thread: current(),
139-
a_thread_panicked: AtomicBool::new(false),
140-
}),
141-
env: PhantomData,
142-
scope: PhantomData,
143-
};
242+
// We can store the ScopeData on the stack as we're careful about accessing it intrusively.
243+
let data = ScopeData::default();
244+
245+
// Make sure to store the ScopeData as Pinned to document in the type system
246+
// that it must remain valid until it is dropped at the end of this function.
247+
// SAFETY: the ScopeData is stored on the stack.
248+
let scope =
249+
Scope { data: unsafe { Pin::new_unchecked(&data) }, env: PhantomData, scope: PhantomData };
144250

145251
// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
146252
let result = catch_unwind(AssertUnwindSafe(|| f(&scope)));
147253

148254
// Wait until all the threads are finished.
149-
while scope.data.num_running_threads.load(Ordering::Acquire) != 0 {
150-
park();
151-
}
255+
// SAFETY: this is the only thread that calls ScopeData::wait_for_running_threads().
256+
unsafe { scope.data.wait_for_running_threads() };
152257

153258
// Throw any panic from `f`, or the return value of `f` if no thread panicked.
154259
match result {
155260
Err(e) => resume_unwind(e),
156-
Ok(_) if scope.data.a_thread_panicked.load(Ordering::Relaxed) => {
261+
Ok(_) if scope.data.thread_panicked.load(Ordering::Relaxed) => {
157262
panic!("a scoped thread panicked")
158263
}
159264
Ok(result) => result,
@@ -252,7 +357,7 @@ impl Builder {
252357
F: FnOnce() -> T + Send + 'scope,
253358
T: Send + 'scope,
254359
{
255-
Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(f, Some(scope.data.clone())) }?))
360+
Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(f, Some(scope.data)) }?))
256361
}
257362
}
258363

@@ -327,10 +432,14 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
327432
#[stable(feature = "scoped_threads", since = "1.63.0")]
328433
impl fmt::Debug for Scope<'_, '_> {
329434
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435+
let state = self.data.sync_state.load(Ordering::Relaxed);
436+
let num_running_threads = state / ONE_RUNNING;
437+
let main_thread_waiting = state & WAITING_BIT != 0;
438+
330439
f.debug_struct("Scope")
331-
.field("num_running_threads", &self.data.num_running_threads.load(Ordering::Relaxed))
332-
.field("a_thread_panicked", &self.data.a_thread_panicked.load(Ordering::Relaxed))
333-
.field("main_thread", &self.data.main_thread)
440+
.field("num_running_threads", &num_running_threads)
441+
.field("thread_panicked", &self.data.thread_panicked.load(Ordering::Relaxed))
442+
.field("main_thread_waiting", &main_thread_waiting)
334443
.finish_non_exhaustive()
335444
}
336445
}

0 commit comments

Comments
 (0)