Skip to content

Commit 1e2cf3c

Browse files
committed
Add callbacks for when threads start and stop doing work
1 parent 9be36f8 commit 1e2cf3c

File tree

4 files changed

+107
-25
lines changed

4 files changed

+107
-25
lines changed

rayon-core/src/lib.rs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,12 @@ pub struct ThreadPoolBuilder<S = DefaultSpawn> {
199199
/// Closure invoked to spawn threads.
200200
spawn_handler: S,
201201

202+
/// Closure invoked when starting computations in a thread.
203+
acquire_thread_handler: Option<Box<AcquireThreadHandler>>,
204+
205+
/// Closure invoked when blocking in a thread.
206+
release_thread_handler: Option<Box<ReleaseThreadHandler>>,
207+
202208
/// If false, worker threads will execute spawned jobs in a
203209
/// "depth-first" fashion. If true, they will do a "breadth-first"
204210
/// fashion. Depth-first is the default.
@@ -242,12 +248,22 @@ impl Default for ThreadPoolBuilder {
242248
start_handler: None,
243249
exit_handler: None,
244250
deadlock_handler: None,
251+
acquire_thread_handler: None,
252+
release_thread_handler: None,
245253
spawn_handler: DefaultSpawn,
246254
breadth_first: false,
247255
}
248256
}
249257
}
250258

259+
/// The type for a closure that gets invoked before starting computations in a thread.
260+
/// Note that this same closure may be invoked multiple times in parallel.
261+
type AcquireThreadHandler = dyn Fn() + Send + Sync;
262+
263+
/// The type for a closure that gets invoked before blocking in a thread.
264+
/// Note that this same closure may be invoked multiple times in parallel.
265+
type ReleaseThreadHandler = dyn Fn() + Send + Sync;
266+
251267
impl ThreadPoolBuilder {
252268
/// Creates and returns a valid rayon thread pool builder, but does not initialize it.
253269
pub fn new() -> Self {
@@ -348,7 +364,12 @@ impl ThreadPoolBuilder {
348364
Ok(())
349365
})
350366
.build()?;
351-
Ok(with_pool(&pool))
367+
let result = unwind::halt_unwinding(|| with_pool(&pool));
368+
pool.wait_until_stopped();
369+
match result {
370+
Ok(result) => Ok(result),
371+
Err(err) => unwind::resume_unwinding(err),
372+
}
352373
});
353374

354375
match result {
@@ -460,6 +481,8 @@ impl<S> ThreadPoolBuilder<S> {
460481
start_handler: self.start_handler,
461482
exit_handler: self.exit_handler,
462483
deadlock_handler: self.deadlock_handler,
484+
acquire_thread_handler: self.acquire_thread_handler,
485+
release_thread_handler: self.release_thread_handler,
463486
breadth_first: self.breadth_first,
464487
}
465488
}
@@ -618,6 +641,34 @@ impl<S> ThreadPoolBuilder<S> {
618641
self.breadth_first
619642
}
620643

644+
/// Takes the current acquire thread callback, leaving `None`.
645+
fn take_acquire_thread_handler(&mut self) -> Option<Box<AcquireThreadHandler>> {
646+
self.acquire_thread_handler.take()
647+
}
648+
649+
/// Set a callback to be invoked when starting computations in a thread.
650+
pub fn acquire_thread_handler<H>(mut self, acquire_thread_handler: H) -> Self
651+
where
652+
H: Fn() + Send + Sync + 'static,
653+
{
654+
self.acquire_thread_handler = Some(Box::new(acquire_thread_handler));
655+
self
656+
}
657+
658+
/// Takes the current release thread callback, leaving `None`.
659+
fn take_release_thread_handler(&mut self) -> Option<Box<ReleaseThreadHandler>> {
660+
self.release_thread_handler.take()
661+
}
662+
663+
/// Set a callback to be invoked when blocking in thread.
664+
pub fn release_thread_handler<H>(mut self, release_thread_handler: H) -> Self
665+
where
666+
H: Fn() + Send + Sync + 'static,
667+
{
668+
self.release_thread_handler = Some(Box::new(release_thread_handler));
669+
self
670+
}
671+
621672
/// Takes the current deadlock callback, leaving `None`.
622673
fn take_deadlock_handler(&mut self) -> Option<Box<DeadlockHandler>> {
623674
self.deadlock_handler.take()
@@ -801,6 +852,8 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
801852
ref deadlock_handler,
802853
ref start_handler,
803854
ref exit_handler,
855+
ref acquire_thread_handler,
856+
ref release_thread_handler,
804857
spawn_handler: _,
805858
ref breadth_first,
806859
} = *self;
@@ -818,6 +871,8 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
818871
let deadlock_handler = deadlock_handler.as_ref().map(|_| ClosurePlaceholder);
819872
let start_handler = start_handler.as_ref().map(|_| ClosurePlaceholder);
820873
let exit_handler = exit_handler.as_ref().map(|_| ClosurePlaceholder);
874+
let acquire_thread_handler = acquire_thread_handler.as_ref().map(|_| ClosurePlaceholder);
875+
let release_thread_handler = release_thread_handler.as_ref().map(|_| ClosurePlaceholder);
821876

822877
f.debug_struct("ThreadPoolBuilder")
823878
.field("num_threads", num_threads)
@@ -827,6 +882,8 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
827882
.field("deadlock_handler", &deadlock_handler)
828883
.field("start_handler", &start_handler)
829884
.field("exit_handler", &exit_handler)
885+
.field("acquire_thread_handler", &acquire_thread_handler)
886+
.field("release_thread_handler", &release_thread_handler)
830887
.field("breadth_first", &breadth_first)
831888
.finish()
832889
}

rayon-core/src/registry.rs

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use crate::sleep::Sleep;
66
use crate::tlv::Tlv;
77
use crate::unwind;
88
use crate::{
9-
DeadlockHandler, ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError,
10-
ThreadPoolBuilder, Yield,
9+
AcquireThreadHandler, DeadlockHandler, ErrorKind, ExitHandler, PanicHandler,
10+
ReleaseThreadHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder, Yield,
1111
};
1212
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
1313
use std::cell::Cell;
@@ -137,9 +137,11 @@ pub struct Registry {
137137
injected_jobs: Injector<JobRef>,
138138
broadcasts: Mutex<Vec<Worker<JobRef>>>,
139139
panic_handler: Option<Box<PanicHandler>>,
140-
deadlock_handler: Option<Box<DeadlockHandler>>,
140+
pub(crate) deadlock_handler: Option<Box<DeadlockHandler>>,
141141
start_handler: Option<Box<StartHandler>>,
142142
exit_handler: Option<Box<ExitHandler>>,
143+
pub(crate) acquire_thread_handler: Option<Box<AcquireThreadHandler>>,
144+
pub(crate) release_thread_handler: Option<Box<ReleaseThreadHandler>>,
143145

144146
// When this latch reaches 0, it means that all work on this
145147
// registry must be complete. This is ensured in the following ways:
@@ -294,6 +296,8 @@ impl Registry {
294296
deadlock_handler: builder.take_deadlock_handler(),
295297
start_handler: builder.take_start_handler(),
296298
exit_handler: builder.take_exit_handler(),
299+
acquire_thread_handler: builder.take_acquire_thread_handler(),
300+
release_thread_handler: builder.take_release_thread_handler(),
297301
});
298302

299303
// If we return early or panic, make sure to terminate existing threads.
@@ -398,11 +402,24 @@ impl Registry {
398402

399403
/// Waits for the worker threads to stop. This is used for testing
400404
/// -- so we can check that termination actually works.
401-
#[cfg(test)]
402405
pub(super) fn wait_until_stopped(&self) {
406+
self.release_thread();
403407
for info in &self.thread_infos {
404408
info.stopped.wait();
405409
}
410+
self.acquire_thread();
411+
}
412+
413+
pub(crate) fn acquire_thread(&self) {
414+
if let Some(ref acquire_thread_handler) = self.acquire_thread_handler {
415+
acquire_thread_handler();
416+
}
417+
}
418+
419+
pub(crate) fn release_thread(&self) {
420+
if let Some(ref release_thread_handler) = self.release_thread_handler {
421+
release_thread_handler();
422+
}
406423
}
407424

408425
/// ////////////////////////////////////////////////////////////////////////
@@ -448,7 +465,7 @@ impl Registry {
448465
self.sleep.new_injected_jobs(usize::MAX, 1, queue_was_empty);
449466
}
450467

451-
fn has_injected_job(&self) -> bool {
468+
pub(crate) fn has_injected_job(&self) -> bool {
452469
!self.injected_jobs.is_empty()
453470
}
454471

@@ -547,7 +564,9 @@ impl Registry {
547564
LatchRef::new(l),
548565
);
549566
self.inject(job.as_job_ref());
567+
self.release_thread();
550568
job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
569+
self.acquire_thread();
551570

552571
// flush accumulated logs as we exit the thread
553572
self.logger.log(|| Flush);
@@ -813,7 +832,7 @@ impl WorkerThread {
813832
}
814833
}
815834

816-
fn has_injected_job(&self) -> bool {
835+
pub(super) fn has_injected_job(&self) -> bool {
817836
!self.stealer.is_empty() || self.registry.has_injected_job()
818837
}
819838

@@ -843,12 +862,9 @@ impl WorkerThread {
843862
self.execute(job);
844863
idle_state = self.registry.sleep.start_looking(self.index, latch);
845864
} else {
846-
self.registry.sleep.no_work_found(
847-
&mut idle_state,
848-
latch,
849-
|| self.has_injected_job(),
850-
&self.registry.deadlock_handler,
851-
)
865+
self.registry
866+
.sleep
867+
.no_work_found(&mut idle_state, latch, &self)
852868
}
853869
}
854870

@@ -971,6 +987,7 @@ unsafe fn main_loop(thread: ThreadBuilder) {
971987
worker: index,
972988
terminate_addr: my_terminate_latch.as_core_latch().addr(),
973989
});
990+
registry.acquire_thread();
974991
worker_thread.wait_until(my_terminate_latch);
975992

976993
// Should not be any work left in our queue.
@@ -989,6 +1006,8 @@ unsafe fn main_loop(thread: ThreadBuilder) {
9891006
registry.catch_unwind(|| handler(index));
9901007
// We're already exiting the thread, there's nothing else to do.
9911008
}
1009+
1010+
registry.release_thread();
9921011
}
9931012

9941013
/// If already in a worker-thread, just execute `op`. Otherwise,

rayon-core/src/sleep/mod.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use crate::latch::CoreLatch;
55
use crate::log::Event::*;
66
use crate::log::Logger;
7+
use crate::registry::WorkerThread;
78
use crate::DeadlockHandler;
89
use crossbeam_utils::CachePadded;
910
use std::sync::atomic::Ordering;
@@ -160,8 +161,7 @@ impl Sleep {
160161
&self,
161162
idle_state: &mut IdleState,
162163
latch: &CoreLatch,
163-
has_injected_jobs: impl FnOnce() -> bool,
164-
deadlock_handler: &Option<Box<DeadlockHandler>>,
164+
thread: &WorkerThread,
165165
) {
166166
if idle_state.rounds < ROUNDS_UNTIL_SLEEPY {
167167
thread::yield_now();
@@ -175,7 +175,7 @@ impl Sleep {
175175
thread::yield_now();
176176
} else {
177177
debug_assert_eq!(idle_state.rounds, ROUNDS_UNTIL_SLEEPING);
178-
self.sleep(idle_state, latch, has_injected_jobs, deadlock_handler);
178+
self.sleep(idle_state, latch, thread);
179179
}
180180
}
181181

@@ -193,13 +193,7 @@ impl Sleep {
193193
}
194194

195195
#[cold]
196-
fn sleep(
197-
&self,
198-
idle_state: &mut IdleState,
199-
latch: &CoreLatch,
200-
has_injected_jobs: impl FnOnce() -> bool,
201-
deadlock_handler: &Option<Box<DeadlockHandler>>,
202-
) {
196+
fn sleep(&self, idle_state: &mut IdleState, latch: &CoreLatch, thread: &WorkerThread) {
203197
let worker_index = idle_state.worker_index;
204198

205199
if !latch.get_sleepy() {
@@ -266,7 +260,7 @@ impl Sleep {
266260
// - that job triggers the rollover over the JEC such that we don't see it
267261
// - we are the last active worker thread
268262
std::sync::atomic::fence(Ordering::SeqCst);
269-
if has_injected_jobs() {
263+
if thread.has_injected_job() {
270264
// If we see an externally injected job, then we have to 'wake
271265
// ourselves up'. (Ordinarily, `sub_sleeping_thread` is invoked by
272266
// the one that wakes us.)
@@ -276,7 +270,7 @@ impl Sleep {
276270
// Decrement the number of active threads and check for a deadlock
277271
let mut data = self.data.lock().unwrap();
278272
data.active_threads -= 1;
279-
data.deadlock_check(deadlock_handler);
273+
data.deadlock_check(&thread.registry.deadlock_handler);
280274
}
281275

282276
// If we don't see an injected job (the normal case), then flag
@@ -287,10 +281,16 @@ impl Sleep {
287281
// that whomever is coming to wake us will have to wait until we
288282
// release the mutex in the call to `wait`, so they will see this
289283
// boolean as true.)
284+
thread.registry.release_thread();
290285
*is_blocked = true;
291286
while *is_blocked {
292287
is_blocked = sleep_state.condvar.wait(is_blocked).unwrap();
293288
}
289+
290+
// Drop `is_blocked` now in case `acquire_thread` blocks
291+
drop(is_blocked);
292+
293+
thread.registry.acquire_thread();
294294
}
295295

296296
// Update other state:

rayon-core/src/thread_pool/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,12 @@ impl ThreadPool {
363363
let curr = self.registry.current_thread()?;
364364
Some(curr.yield_local())
365365
}
366+
367+
pub(crate) fn wait_until_stopped(self) {
368+
let registry = self.registry.clone();
369+
drop(self);
370+
registry.wait_until_stopped();
371+
}
366372
}
367373

368374
impl Drop for ThreadPool {

0 commit comments

Comments
 (0)