Skip to content

Commit a346bf0

Browse files
committed
Check full in fold & find_any consume_iter, remove for filter*
These adaptors consume may many elements before deferring to a base folder's fullness checks, and so they need to be performed manually. For the `filter`s, there's no way to do it manually (rayon-rs#632), so the specialisations just have to be removed. For `fold` and `find_any` this can be done with a `take_while`. This extends the octillion tests to confirm this behaviour. This makes a program like the following slightly slower compared to the direct `consume_iter` without a check, but it's still faster than the non-specialized form. ``` extern crate test; extern crate rayon; use rayon::prelude::*; fn main() { let count = (0..std::u32::MAX) .into_par_iter() .map(test::black_box) .find_any(|_| test::black_box(false)); println!("{:?}", count); } ``` ``` $ hyperfine ./find-original ./find-no-check ./find-check Benchmark #1: ./find-original Time (mean ± σ): 627.6 ms ± 25.7 ms [User: 7.130 s, System: 0.014 s] Range (min … max): 588.4 ms … 656.4 ms 10 runs Benchmark #2: ./find-no-check Time (mean ± σ): 481.5 ms ± 10.8 ms [User: 5.415 s, System: 0.013 s] Range (min … max): 468.9 ms … 498.2 ms 10 runs Benchmark #3: ./find-check Time (mean ± σ): 562.3 ms ± 11.8 ms [User: 6.363 s, System: 0.013 s] Range (min … max): 542.5 ms … 578.2 ms 10 runs ``` (find-original = without specialization, find-no-check = custom `consume_iter` without `take_while`, find-check = this commit)
1 parent cabe301 commit a346bf0

File tree

5 files changed

+62
-30
lines changed

5 files changed

+62
-30
lines changed

src/iter/filter.rs

+3-8
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,9 @@ where
136136
}
137137
}
138138

139-
fn consume_iter<I>(mut self, iter: I) -> Self
140-
where
141-
I: IntoIterator<Item = T>
142-
{
143-
self.base = self.base.consume_iter(iter.into_iter().filter(self.filter_op));
144-
self
145-
}
146-
139+
// This cannot easily specialize `consume_iter` to be better than
140+
// the default, because that requires checking `self.base.full()`
141+
// during a call to `self.base.consume_iter()`. (#632)
147142

148143
fn complete(self) -> Self::Result {
149144
self.base.complete()

src/iter/filter_map.rs

+4-7
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,10 @@ where
139139
self
140140
}
141141
}
142-
fn consume_iter<I>(mut self, iter: I) -> Self
143-
where
144-
I: IntoIterator<Item = T>
145-
{
146-
self.base = self.base.consume_iter(iter.into_iter().filter_map(self.filter_op));
147-
self
148-
}
142+
143+
// This cannot easily specialize `consume_iter` to be better than
144+
// the default, because that requires checking `self.base.full()`
145+
// during a call to `self.base.consume_iter()`. (#632)
149146

150147
fn complete(self) -> C::Result {
151148
self.base.complete()

src/iter/find.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ where
8989
fn consume_iter<I>(mut self, iter: I) -> Self
9090
where I: IntoIterator<Item = T>
9191
{
92-
self.item = iter.into_iter().find(self.find_op);
92+
self.item = iter
93+
.into_iter()
94+
// stop iterating if another thread has found something
95+
.take_while(|_| !self.full())
96+
.find(self.find_op);
9397
if self.item.is_some() {
9498
self.found.store(true, Ordering::Relaxed)
9599
}

src/iter/fold.rs

+9-3
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,18 @@ where
141141
}
142142
}
143143

144-
fn consume_iter<I>(mut self, iter: I) -> Self
144+
fn consume_iter<I>(self, iter: I) -> Self
145145
where
146146
I: IntoIterator<Item = T>
147147
{
148-
self.item = iter.into_iter().fold(self.item, self.fold_op);
149-
self
148+
let base = self.base;
149+
let item = iter
150+
.into_iter()
151+
// stop iterating if another thread has finished
152+
.take_while(|_| !base.full())
153+
.fold(self.item, self.fold_op);
154+
155+
FoldFolder { base: base, item: item, fold_op: self.fold_op }
150156
}
151157

152158
fn complete(self) -> C::Result {

tests/octillion.rs

+41-11
Original file line numberDiff line numberDiff line change
@@ -42,27 +42,57 @@ fn find_first_octillion_flat() {
4242
assert_eq!(x, Some(0));
4343
}
4444

45-
#[test]
46-
fn find_last_octillion() {
45+
fn two_threads<F: Send + FnOnce() -> R, R: Send>(f: F) -> R {
4746
// FIXME: If we don't use at least two threads, then we end up walking
4847
// through the entire iterator sequentially, without the benefit of any
4948
// short-circuiting. We probably don't want testing to wait that long. ;)
50-
// It would be nice if `find_last` could prioritize the later splits,
51-
// basically flipping the `join` args, without needing indexed `rev`.
52-
// (or could we have an unindexed `rev`?)
5349
let builder = rayon::ThreadPoolBuilder::new().num_threads(2);
5450
let pool = builder.build().unwrap();
5551

56-
let x = pool.install(|| octillion().find_last(|_| true));
52+
pool.install(f)
53+
}
54+
55+
#[test]
56+
fn find_last_octillion() {
57+
// It would be nice if `find_last` could prioritize the later splits,
58+
// basically flipping the `join` args, without needing indexed `rev`.
59+
// (or could we have an unindexed `rev`?)
60+
let x = two_threads(|| octillion().find_last(|_| true));
5761
assert_eq!(x, Some(OCTILLION - 1));
5862
}
5963

6064
#[test]
6165
fn find_last_octillion_flat() {
62-
// FIXME: Ditto, need two threads.
63-
let builder = rayon::ThreadPoolBuilder::new().num_threads(2);
64-
let pool = builder.build().unwrap();
65-
66-
let x = pool.install(|| octillion_flat().find_last(|_| true));
66+
let x = two_threads(|| octillion_flat().find_last(|_| true));
6767
assert_eq!(x, Some(OCTILLION - 1));
6868
}
69+
70+
#[test]
71+
fn find_any_octillion() {
72+
let x = two_threads(|| octillion().find_any(|x| *x > OCTILLION / 2));
73+
assert!(x.is_some());
74+
}
75+
76+
#[test]
77+
fn find_any_octillion_flat() {
78+
let x = two_threads(|| octillion_flat().find_any(|x| *x > OCTILLION / 2));
79+
assert!(x.is_some());
80+
}
81+
82+
#[test]
83+
fn filter_find_any_octillion() {
84+
let x = two_threads(|| octillion().filter(|x| *x > OCTILLION / 2).find_any(|_| true));
85+
assert!(x.is_some());
86+
}
87+
88+
#[test]
89+
fn filter_find_any_octillion_flat() {
90+
let x = two_threads(|| octillion_flat().filter(|x| *x > OCTILLION / 2).find_any(|_| true));
91+
assert!(x.is_some());
92+
}
93+
94+
#[test]
95+
fn fold_find_any_octillion_flat() {
96+
let x = two_threads(|| octillion_flat().fold(|| (), |_, _| ()).find_any(|_| true));
97+
assert!(x.is_some());
98+
}

0 commit comments

Comments
 (0)