Skip to content

std::thread::scope() improvements #98517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions library/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@
#![feature(prelude_2024)]
#![feature(ptr_as_uninit)]
#![feature(raw_os_nonzero)]
#![feature(cfg_sanitize)]
#![feature(slice_internals)]
#![feature(slice_ptr_get)]
#![feature(std_internals)]
Expand Down
27 changes: 19 additions & 8 deletions library/std/src/thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ use crate::num::NonZeroUsize;
use crate::panic;
use crate::panicking;
use crate::pin::Pin;
use crate::ptr::addr_of_mut;
use crate::ptr::{addr_of_mut, NonNull};
use crate::str;
use crate::sync::Arc;
use crate::sys::thread as imp;
Expand Down Expand Up @@ -463,7 +463,7 @@ impl Builder {
unsafe fn spawn_unchecked_<'a, 'scope, F, T>(
self,
f: F,
scope_data: Option<Arc<scoped::ScopeData>>,
scope_data: Option<Pin<&'scope scoped::ScopeData>>,
) -> io::Result<JoinInner<'scope, T>>
where
F: FnOnce() -> T,
Expand All @@ -481,7 +481,7 @@ impl Builder {
let their_thread = my_thread.clone();

let my_packet: Arc<Packet<'scope, T>> = Arc::new(Packet {
scope: scope_data,
scope: scope_data.map(|data| NonNull::from(data.get_ref())),
result: UnsafeCell::new(None),
_marker: PhantomData,
});
Expand Down Expand Up @@ -511,8 +511,8 @@ impl Builder {
unsafe { *their_packet.result.get() = Some(try_result) };
};

if let Some(scope_data) = &my_packet.scope {
scope_data.increment_num_running_threads();
if let Some(scope_data) = my_packet.scope {
unsafe { scope_data.as_ref().increment_num_running_threads() };
}

Ok(JoinInner {
Expand Down Expand Up @@ -1302,7 +1302,7 @@ pub type Result<T> = crate::result::Result<T, Box<dyn Any + Send + 'static>>;
// An Arc to the packet is stored into a `JoinInner` which in turns is placed
// in `JoinHandle`.
struct Packet<'scope, T> {
scope: Option<Arc<scoped::ScopeData>>,
scope: Option<NonNull<scoped::ScopeData>>,
result: UnsafeCell<Option<Result<T>>>,
_marker: PhantomData<Option<&'scope scoped::ScopeData>>,
}
Expand Down Expand Up @@ -1335,12 +1335,23 @@ impl<'scope, T> Drop for Packet<'scope, T> {
rtabort!("thread result panicked on drop");
}
// Book-keeping so the scope knows when it's done.
if let Some(scope) = &self.scope {
if let Some(scope_data) = self.scope {
// Now that there will be no more user code running on this thread
// that can use 'scope, mark the thread as 'finished'.
// It's important we only do this after the `result` has been dropped,
// since dropping it might still use things it borrowed from 'scope.
scope.decrement_num_running_threads(unhandled_panic);
//
// A static method to decrement is used to keep `ScopeData` as a raw pointer.
// Using a reference risks the decrement function waking the `scope()` thread,
// invalidating our `ScopeData`, and leaving us with a dangling dereferenceable &ScopeData.
// This avoids issue #55005.
//
// SAFETY:
// Given the thread has been spawned,
// there was a matching call to `ScopeData::increment_num_running_threads()`.
unsafe {
scoped::ScopeData::decrement_num_running_threads(scope_data, unhandled_panic);
}
}
}
}
Expand Down
181 changes: 145 additions & 36 deletions library/std/src/thread/scoped.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use super::{current, park, Builder, JoinInner, Result, Thread};
use crate::cell::UnsafeCell;
use crate::fmt;
use crate::io;
use crate::marker::PhantomData;
use crate::marker::{PhantomData, PhantomPinned};
use crate::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use crate::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use crate::pin::Pin;
use crate::ptr::NonNull;
use crate::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};
use crate::sync::Arc;
use core::intrinsics::{atomic_store_rel, atomic_xsub_rel};

/// A scope to spawn scoped threads in.
///
/// See [`scope`] for details.
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub struct Scope<'scope, 'env: 'scope> {
data: Arc<ScopeData>,
data: Pin<&'scope ScopeData>,
/// Invariance over 'scope, to make sure 'scope cannot shrink,
/// which is necessary for soundness.
///
Expand All @@ -35,28 +39,133 @@ pub struct Scope<'scope, 'env: 'scope> {
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>);

/// Flag bit (bit 0) of `ScopeData::sync_state`, set while the scope thread has
/// registered itself as a waiter.
const WAITING_BIT: usize = 1;
/// Amount added to `ScopeData::sync_state` for each running thread; the number of
/// running threads is therefore `sync_state / ONE_RUNNING`.
const ONE_RUNNING: usize = 2;

/// Artificial limit on the maximum number of concurrently running threads in scope.
/// This is used to preemptively avoid hitting an overflow condition in the running thread count.
///
/// The running count is stored as `sync_state / ONE_RUNNING`, so it can never exceed
/// `usize::MAX / ONE_RUNNING`. The limit has to be strictly below that bound, otherwise a
/// `running_threads <= MAX_RUNNING` check can never fail and a wrapping `fetch_add` would go
/// unnoticed. Using half of the representable range leaves headroom for increments that
/// race with the check on other threads.
const MAX_RUNNING: usize = (usize::MAX / ONE_RUNNING) / 2;

#[derive(Default)]
pub(super) struct ScopeData {
num_running_threads: AtomicUsize,
a_thread_panicked: AtomicBool,
main_thread: Thread,
sync_state: AtomicUsize,
thread_panicked: AtomicBool,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the previous name makes more sense. This is true when any thread panicked. thread_panicked to me suggests that there is only a single thread.

scope_thread: UnsafeCell<Option<Thread>>,
_pinned: PhantomPinned,
}

unsafe impl Send for ScopeData {} // SAFETY: ScopeData needs to be sent to the spawned threads in the scope.
unsafe impl Sync for ScopeData {} // SAFETY: ScopeData is shared between the spawned threads and the scope thread.

impl ScopeData {
/// Establishes Acquire ordering against the Release sequence on `sync_state`.
///
/// Normally this is a standalone `fence(Ordering::Acquire)`. When compiled with
/// `sanitize = "thread"`, an Acquire load of `sync_state` is performed instead.
fn fence_acquire_sync_state(&self) {
    let under_thread_sanitizer = cfg!(sanitize = "thread");
    if !under_thread_sanitizer {
        fence(Ordering::Acquire);
    } else {
        // ThreadSanitizer doesn't properly support fences, so an atomic load is
        // used in its place to avoid false positive data-race reports.
        self.sync_state.load(Ordering::Acquire);
    }
}

pub(super) fn increment_num_running_threads(&self) {
// We check for 'overflow' with usize::MAX / 2, to make sure there's no
// chance it overflows to 0, which would result in unsoundness.
if self.num_running_threads.fetch_add(1, Ordering::Relaxed) > usize::MAX / 2 {
// This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles.
self.decrement_num_running_threads(false);
panic!("too many running threads in thread scope");
// No need for any memory barriers as this is just incrementing the running count
// with the assumption that the ScopeData remains valid before and after this call.
let state = self.sync_state.fetch_add(ONE_RUNNING, Ordering::Relaxed);

// Make sure we're not spawning too many threads on the scope.
// The `MAX_RUNNING` is intentionally lower than `usize::MAX` to detect overflow
// conditions on the running count earlier, even in the presence of multiple threads.
let running_threads = state / ONE_RUNNING;
assert!(running_threads <= MAX_RUNNING, "too many running threads in thread scope");
}

/// Decrement the number of running threads with the assumption that one was running before.
/// Once the number of running threads becomes zero, it wakes up the scope thread if it's waiting.
/// The running thread count hitting zero "happens before" the scope thread returns from waiting.
///
/// This is an associated function taking `NonNull<Self>` rather than `&self`: `data` is only
/// dereferenced through short-lived `as_ref()` calls at the points where it is used.
///
/// If `panicked` is true, `thread_panicked` is set before the count is decremented.
///
/// SAFETY:
/// Caller must ensure that there was a matching call to `increment_num_running_threads()` prior.
pub(super) unsafe fn decrement_num_running_threads(data: NonNull<Self>, panicked: bool) {
unsafe {
if panicked {
data.as_ref().thread_panicked.store(true, Ordering::Relaxed);
}

// Decrement the running count with a Release barrier.
// This ensures that all data accesses and side effects before the decrement
// "happen before" the scope thread observes the running count to be zero.
let state_ptr = data.as_ref().sync_state.as_mut_ptr();
let state = atomic_xsub_rel(state_ptr, ONE_RUNNING);

// `state` is the value *before* the subtraction, so a count of zero here means
// there was no matching increment.
let running_threads = state / ONE_RUNNING;
assert_ne!(
running_threads, 0,
"decrement_num_running_threads called when not incremented"
);

// Wake up the scope thread if it's waiting and if we're the last running thread.
// (Previous state was exactly one running thread plus the waiter flag.)
if state == (ONE_RUNNING | WAITING_BIT) {
// Acquire barrier ensures that both the scope_thread store and WAITING_BIT set,
// along with the data accesses and decrements from previous threads,
// "happen before" we start to wake up the scope thread.
data.as_ref().fence_acquire_sync_state();

// Move the registered `Thread` handle out of the `ScopeData` into a local,
// so that unparking below does not need to touch `data` again.
let scope_thread = {
let thread_ref = &mut *data.as_ref().scope_thread.get();
thread_ref.take().expect("ScopeData has no thread even when WAITING_BIT is set")
};

// Wake up the scope thread by removing the WAITING_BIT and unparking the thread.
// Release barrier ensures the consume of `scope_thread` "happens before" the
// waiting scope thread observes 0 and returns to invalidate our data pointer.
atomic_store_rel(state_ptr, 0);
scope_thread.unpark();
}
}
}
pub(super) fn decrement_num_running_threads(&self, panic: bool) {
if panic {
self.a_thread_panicked.store(true, Ordering::Relaxed);

/// Blocks the caller's thread until all running threads have called `decrement_num_running_threads()`.
///
/// SAFETY:
/// Caller must ensure that they're the sole scope_thread calling this function.
/// There should also be no future calls to `increment_num_running_threads()` at this point.
unsafe fn wait_for_running_threads(&self) {
// Fast check to see if no threads are running.
// Acquire barrier ensures the running thread count updates
// and previous side effects on those threads "happen before" we observe 0 and return.
if self.sync_state.load(Ordering::Acquire) == 0 {
return;
}
if self.num_running_threads.fetch_sub(1, Ordering::Release) == 1 {
self.main_thread.unpark();

// Register our Thread object to be unparked.
unsafe {
let thread_ref = &mut *self.scope_thread.get();
let old_scope_thread = thread_ref.replace(current());
assert!(old_scope_thread.is_none(), "multiple threads waiting on same ScopeData");
}

// Set the WAITING_BIT on the state to indicate there's a waiter.
// Uses `fetch_add` over `fetch_or` as the former compiles to accelerated instructions on modern CPUs.
// Release barrier ensures Thread registration above "happens before" WAITING_BIT is observed by last running thread.
let state = self.sync_state.fetch_add(WAITING_BIT, Ordering::Release);
assert_eq!(state & WAITING_BIT, 0, "multiple threads waiting on same ScopeData");

// Don't wait if all running threads completed while we were trying to set the WAITING_BIT.
// Acquire barrier ensures all running thread count updates and related side effects "happen before" we return.
if state / ONE_RUNNING == 0 {
self.fence_acquire_sync_state();
return;
}

// Block the thread until the last running thread sees the WAITING_BIT and resets the state to zero.
// Acquire barrier ensures all running thread count updates and related side effects "happen before" we return.
loop {
park();
if self.sync_state.load(Ordering::Acquire) == 0 {
return;
}
}
}
}
Expand Down Expand Up @@ -130,30 +239,26 @@ pub fn scope<'env, F, T>(f: F) -> T
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
// We put the `ScopeData` into an `Arc` so that other threads can finish their
// `decrement_num_running_threads` even after this function returns.
let scope = Scope {
data: Arc::new(ScopeData {
num_running_threads: AtomicUsize::new(0),
main_thread: current(),
a_thread_panicked: AtomicBool::new(false),
}),
env: PhantomData,
scope: PhantomData,
};
// We can store the ScopeData on the stack as we're careful about accessing it intrusively.
let data = ScopeData::default();

// Make sure to store the ScopeData as Pinned to document in the type system
// that it must remain valid until it is dropped at the end of this function.
// SAFETY: the ScopeData is stored on the stack.
let scope =
Scope { data: unsafe { Pin::new_unchecked(&data) }, env: PhantomData, scope: PhantomData };

// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
let result = catch_unwind(AssertUnwindSafe(|| f(&scope)));

// Wait until all the threads are finished.
while scope.data.num_running_threads.load(Ordering::Acquire) != 0 {
park();
}
// SAFETY: this is the only thread that calls ScopeData::wait_for_running_threads().
unsafe { scope.data.wait_for_running_threads() };

// Throw any panic from `f`, or the return value of `f` if no thread panicked.
match result {
Err(e) => resume_unwind(e),
Ok(_) if scope.data.a_thread_panicked.load(Ordering::Relaxed) => {
Ok(_) if scope.data.thread_panicked.load(Ordering::Relaxed) => {
panic!("a scoped thread panicked")
}
Ok(result) => result,
Expand Down Expand Up @@ -252,7 +357,7 @@ impl Builder {
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(f, Some(scope.data.clone())) }?))
Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(f, Some(scope.data)) }?))
}
}

Expand Down Expand Up @@ -327,10 +432,14 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
#[stable(feature = "scoped_threads", since = "1.63.0")]
impl fmt::Debug for Scope<'_, '_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let state = self.data.sync_state.load(Ordering::Relaxed);
let num_running_threads = state / ONE_RUNNING;
let main_thread_waiting = state & WAITING_BIT != 0;

f.debug_struct("Scope")
.field("num_running_threads", &self.data.num_running_threads.load(Ordering::Relaxed))
.field("a_thread_panicked", &self.data.a_thread_panicked.load(Ordering::Relaxed))
.field("main_thread", &self.data.main_thread)
.field("num_running_threads", &num_running_threads)
.field("thread_panicked", &self.data.thread_panicked.load(Ordering::Relaxed))
.field("main_thread_waiting", &main_thread_waiting)
.finish_non_exhaustive()
}
}
Expand Down