Skip to content

Commit c57dfbd

Browse files
committed
Only work-steal in the main loop
1 parent 5fadf44 commit c57dfbd

File tree

18 files changed

+308
-132
lines changed

18 files changed

+308
-132
lines changed

rayon-core/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ name = "rayon_core"
2020
[dependencies]
2121
crossbeam-deque = "0.8.1"
2222
crossbeam-utils = "0.8.0"
23+
smallvec = "1.11.0"
2324

2425
[dev-dependencies]
2526
rand = "0.9"

rayon-core/src/broadcast/mod.rs

+21-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::latch::{CountLatch, LatchRef};
33
use crate::registry::{Registry, WorkerThread};
44
use std::fmt;
55
use std::marker::PhantomData;
6+
use std::sync::atomic::{AtomicBool, Ordering};
67
use std::sync::Arc;
78

89
mod test;
@@ -99,13 +100,22 @@ where
99100
OP: Fn(BroadcastContext<'_>) -> R + Sync,
100101
R: Send,
101102
{
103+
let current_thread = WorkerThread::current();
104+
let current_thread_addr = current_thread as usize;
105+
let started = &AtomicBool::new(false);
102106
let f = move |injected: bool| {
103107
debug_assert!(injected);
108+
109+
// Mark as started if we are on the thread that initiated the broadcast.
110+
if current_thread_addr == WorkerThread::current() as usize {
111+
started.store(true, Ordering::Relaxed);
112+
}
113+
104114
BroadcastContext::with(&op)
105115
};
106116

107117
let n_threads = registry.num_threads();
108-
let current_thread = WorkerThread::current().as_ref();
118+
let current_thread = current_thread.as_ref();
109119
let tlv = crate::tlv::get();
110120
let latch = CountLatch::with_count(n_threads, current_thread);
111121
let jobs: Vec<_> = (0..n_threads)
@@ -115,8 +125,16 @@ where
115125

116126
registry.inject_broadcast(job_refs);
117127

128+
let current_thread_job_id = current_thread
129+
.and_then(|worker| (registry.id() == worker.registry.id()).then(|| worker))
130+
.map(|worker| jobs[worker.index].as_job_ref().id());
131+
118132
// Wait for all jobs to complete, then collect the results, maybe propagating a panic.
119-
latch.wait(current_thread);
133+
latch.wait(
134+
current_thread,
135+
|| started.load(Ordering::Relaxed),
136+
|job| Some(job.id()) == current_thread_job_id,
137+
);
120138
jobs.into_iter().map(|job| job.into_result()).collect()
121139
}
122140

@@ -132,7 +150,7 @@ where
132150
{
133151
let job = ArcJob::new({
134152
let registry = Arc::clone(registry);
135-
move || {
153+
move |_| {
136154
registry.catch_unwind(|| BroadcastContext::with(&op));
137155
registry.terminate(); // (*) permit registry to terminate now
138156
}

rayon-core/src/broadcast/test.rs

+2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ fn spawn_broadcast_self() {
6464
}
6565

6666
#[test]
67+
#[ignore]
6768
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
6869
fn broadcast_mutual() {
6970
let count = AtomicUsize::new(0);
@@ -98,6 +99,7 @@ fn spawn_broadcast_mutual() {
9899
}
99100

100101
#[test]
102+
#[ignore]
101103
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
102104
fn broadcast_mutual_sleepy() {
103105
let count = AtomicUsize::new(0);

rayon-core/src/job.rs

+26-14
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ pub(super) trait Job {
2626
unsafe fn execute(this: *const ());
2727
}
2828

29+
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
30+
pub(super) struct JobRefId {
31+
pointer: usize,
32+
}
33+
2934
/// Effectively a Job trait object. Each JobRef **must** be executed
3035
/// exactly once, or else data may leak.
3136
///
@@ -54,11 +59,11 @@ impl JobRef {
5459
}
5560
}
5661

57-
/// Returns an opaque handle that can be saved and compared,
58-
/// without making `JobRef` itself `Copy + Eq`.
5962
#[inline]
60-
pub(super) fn id(&self) -> impl Eq {
61-
(self.pointer, self.execute_fn)
63+
pub(super) fn id(&self) -> JobRefId {
64+
JobRefId {
65+
pointer: self.pointer as usize,
66+
}
6267
}
6368

6469
#[inline]
@@ -102,8 +107,13 @@ where
102107
JobRef::new(self)
103108
}
104109

105-
pub(super) unsafe fn run_inline(self, stolen: bool) -> R {
106-
self.func.into_inner().unwrap()(stolen)
110+
pub(super) unsafe fn run_inline(&self, stolen: bool) {
111+
let func = (*self.func.get()).take().unwrap();
112+
(*self.result.get()) = match unwind::halt_unwinding(|| func(stolen)) {
113+
Ok(x) => JobResult::Ok(x),
114+
Err(x) => JobResult::Panic(x),
115+
};
116+
Latch::set(&self.latch);
107117
}
108118

109119
pub(super) unsafe fn into_result(self) -> R {
@@ -136,15 +146,15 @@ where
136146
/// (Probably `StackJob` should be refactored in a similar fashion.)
137147
pub(super) struct HeapJob<BODY>
138148
where
139-
BODY: FnOnce() + Send,
149+
BODY: FnOnce(JobRefId) + Send,
140150
{
141151
job: BODY,
142152
tlv: Tlv,
143153
}
144154

145155
impl<BODY> HeapJob<BODY>
146156
where
147-
BODY: FnOnce() + Send,
157+
BODY: FnOnce(JobRefId) + Send,
148158
{
149159
pub(super) fn new(tlv: Tlv, job: BODY) -> Box<Self> {
150160
Box::new(HeapJob { job, tlv })
@@ -168,27 +178,28 @@ where
168178

169179
impl<BODY> Job for HeapJob<BODY>
170180
where
171-
BODY: FnOnce() + Send,
181+
BODY: FnOnce(JobRefId) + Send,
172182
{
173183
unsafe fn execute(this: *const ()) {
184+
let pointer = this as usize;
174185
let this = Box::from_raw(this as *mut Self);
175186
tlv::set(this.tlv);
176-
(this.job)();
187+
(this.job)(JobRefId { pointer });
177188
}
178189
}
179190

180191
/// Represents a job stored in an `Arc` -- like `HeapJob`, but may
181192
/// be turned into multiple `JobRef`s and called multiple times.
182193
pub(super) struct ArcJob<BODY>
183194
where
184-
BODY: Fn() + Send + Sync,
195+
BODY: Fn(JobRefId) + Send + Sync,
185196
{
186197
job: BODY,
187198
}
188199

189200
impl<BODY> ArcJob<BODY>
190201
where
191-
BODY: Fn() + Send + Sync,
202+
BODY: Fn(JobRefId) + Send + Sync,
192203
{
193204
pub(super) fn new(job: BODY) -> Arc<Self> {
194205
Arc::new(ArcJob { job })
@@ -212,11 +223,12 @@ where
212223

213224
impl<BODY> Job for ArcJob<BODY>
214225
where
215-
BODY: Fn() + Send + Sync,
226+
BODY: Fn(JobRefId) + Send + Sync,
216227
{
217228
unsafe fn execute(this: *const ()) {
229+
let pointer = this as usize;
218230
let this = Arc::from_raw(this as *mut Self);
219-
(this.job)();
231+
(this.job)(JobRefId { pointer });
220232
}
221233
}
222234

rayon-core/src/join/mod.rs

+28-54
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
use crate::job::JobRef;
12
use crate::job::StackJob;
23
use crate::latch::SpinLatch;
3-
use crate::registry::{self, WorkerThread};
4-
use crate::tlv::{self, Tlv};
4+
use crate::registry;
5+
use crate::tlv;
56
use crate::unwind;
6-
use std::any::Any;
7+
use std::sync::atomic::{AtomicBool, Ordering};
78

89
use crate::FnContext;
910

@@ -135,68 +136,41 @@ where
135136
// Create virtual wrapper for task b; this all has to be
136137
// done here so that the stack frame can keep it all live
137138
// long enough.
138-
let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new(worker_thread));
139+
let job_b_started = AtomicBool::new(false);
140+
let job_b = StackJob::new(
141+
tlv,
142+
|migrated| {
143+
job_b_started.store(true, Ordering::Relaxed);
144+
call_b(oper_b)(migrated)
145+
},
146+
SpinLatch::new(worker_thread),
147+
);
139148
let job_b_ref = job_b.as_job_ref();
140149
let job_b_id = job_b_ref.id();
141150
worker_thread.push(job_b_ref);
142151

143152
// Execute task a; hopefully b gets stolen in the meantime.
144153
let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
145-
let result_a = match status_a {
146-
Ok(v) => v,
147-
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
148-
};
149-
150-
// Now that task A has finished, try to pop job B from the
151-
// local stack. It may already have been popped by job A; it
152-
// may also have been stolen. There may also be some tasks
153-
// pushed on top of it in the stack, and we will have to pop
154-
// those off to get to it.
155-
while !job_b.latch.probe() {
156-
if let Some(job) = worker_thread.take_local_job() {
157-
if job_b_id == job.id() {
158-
// Found it! Let's run it.
159-
//
160-
// Note that this could panic, but it's ok if we unwind here.
161154

162-
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
163-
tlv::set(tlv);
164-
165-
let result_b = job_b.run_inline(injected);
166-
return (result_a, result_b);
167-
} else {
168-
worker_thread.execute(job);
169-
}
170-
} else {
171-
// Local deque is empty. Time to steal from other
172-
// threads.
173-
worker_thread.wait_until(&job_b.latch);
174-
debug_assert!(job_b.latch.probe());
175-
break;
176-
}
177-
}
155+
// Wait for job B or execute it if it's in the local queue.
156+
worker_thread.wait_for_jobs::<_, false>(
157+
&job_b.latch,
158+
|| job_b_started.load(Ordering::Relaxed),
159+
|job| job.id() == job_b_id,
160+
|job: JobRef| {
161+
debug_assert_eq!(job.id(), job_b_id);
162+
job_b.run_inline(injected);
163+
},
164+
);
178165

179166
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
180167
tlv::set(tlv);
181168

169+
let result_a = match status_a {
170+
Ok(v) => v,
171+
Err(err) => unwind::resume_unwinding(err),
172+
};
173+
182174
(result_a, job_b.into_result())
183175
})
184176
}
185-
186-
/// If job A panics, we still cannot return until we are sure that job
187-
/// B is complete. This is because it may contain references into the
188-
/// enclosing stack frame(s).
189-
#[cold] // cold path
190-
unsafe fn join_recover_from_panic(
191-
worker_thread: &WorkerThread,
192-
job_b_latch: &SpinLatch<'_>,
193-
err: Box<dyn Any + Send>,
194-
tlv: Tlv,
195-
) -> ! {
196-
worker_thread.wait_until(job_b_latch);
197-
198-
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
199-
tlv::set(tlv);
200-
201-
unwind::resume_unwinding(err)
202-
}

rayon-core/src/join/test.rs

+1
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ fn join_context_both() {
9696
}
9797

9898
#[test]
99+
#[ignore]
99100
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
100101
fn join_context_neither() {
101102
// If we're already in a 1-thread pool, neither job should be stolen.

rayon-core/src/latch.rs

+10-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::ops::Deref;
33
use std::sync::atomic::{AtomicUsize, Ordering};
44
use std::sync::{Arc, Condvar, Mutex};
55

6+
use crate::job::JobRef;
67
use crate::registry::{Registry, WorkerThread};
78

89
/// We define various kinds of latches, which are all a primitive signaling
@@ -176,11 +177,6 @@ impl<'r> SpinLatch<'r> {
176177
..SpinLatch::new(thread)
177178
}
178179
}
179-
180-
#[inline]
181-
pub(super) fn probe(&self) -> bool {
182-
self.core_latch.probe()
183-
}
184180
}
185181

186182
impl<'r> AsCoreLatch for SpinLatch<'r> {
@@ -385,7 +381,12 @@ impl CountLatch {
385381
debug_assert!(old_counter != 0);
386382
}
387383

388-
pub(super) fn wait(&self, owner: Option<&WorkerThread>) {
384+
pub(super) fn wait(
385+
&self,
386+
owner: Option<&WorkerThread>,
387+
all_jobs_started: impl FnMut() -> bool,
388+
is_job: impl FnMut(&JobRef) -> bool,
389+
) {
389390
match &self.kind {
390391
CountLatchKind::Stealing {
391392
latch,
@@ -395,7 +396,9 @@ impl CountLatch {
395396
let owner = owner.expect("owner thread");
396397
debug_assert_eq!(registry.id(), owner.registry().id());
397398
debug_assert_eq!(*worker_index, owner.index());
398-
owner.wait_until(latch);
399+
owner.wait_for_jobs::<_, true>(latch, all_jobs_started, is_job, |job| {
400+
owner.execute(job)
401+
});
399402
},
400403
CountLatchKind::Blocking { latch } => latch.wait(),
401404
}

0 commit comments

Comments (0)