Skip to content

Specialize iter::Chain<A, B>::next when A==B #107701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion library/core/benches/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ fn bench_for_each_chain_ref_fold(b: &mut Bencher) {
/// Helper to benchmark `sum` for iterators taken by value which
/// can optimize `fold`, and by reference which cannot.
macro_rules! bench_sums {
($bench_sum:ident, $bench_ref_sum:ident, $iter:expr) => {
($bench_sum:ident, $bench_ref_sum:ident, $bench_ext_sum:ident, $iter:expr) => {
#[bench]
fn $bench_sum(b: &mut Bencher) {
b.iter(|| -> i64 { $iter.map(black_box).sum() });
Expand All @@ -211,151 +211,192 @@ macro_rules! bench_sums {
fn $bench_ref_sum(b: &mut Bencher) {
b.iter(|| -> i64 { $iter.map(black_box).by_ref().sum() });
}

#[bench]
fn $bench_ext_sum(b: &mut Bencher) {
b.iter(|| -> i64 {
let mut sum = 0;
for i in $iter.map(black_box) {
sum += i;
}
sum
});
}
};
}

bench_sums! {
bench_flat_map_sum,
bench_flat_map_ref_sum,
bench_flat_map_ext_sum,
(0i64..1000).flat_map(|x| x..x+1000)
}

bench_sums! {
bench_flat_map_chain_sum,
bench_flat_map_chain_ref_sum,
bench_flat_map_chain_ext_sum,
(0i64..1000000).flat_map(|x| once(x).chain(once(x)))
}

bench_sums! {
bench_enumerate_sum,
bench_enumerate_ref_sum,
bench_enumerate_ext_sum,
(0i64..1000000).enumerate().map(|(i, x)| x * i as i64)
}

bench_sums! {
bench_enumerate_chain_sum,
bench_enumerate_chain_ref_sum,
bench_enumerate_chain_ext_sum,
(0i64..1000000).chain(0..1000000).enumerate().map(|(i, x)| x * i as i64)
}

bench_sums! {
bench_filter_sum,
bench_filter_ref_sum,
bench_filter_ext_sum,
(0i64..1000000).filter(|x| x % 3 == 0)
}

bench_sums! {
bench_filter_chain_sum,
bench_filter_chain_ref_sum,
bench_filter_chain_ext_sum,
(0i64..1000000).chain(0..1000000).filter(|x| x % 3 == 0)
}

bench_sums! {
bench_filter_map_sum,
bench_filter_map_ref_sum,
bench_filter_map_ext_sum,
(0i64..1000000).filter_map(|x| x.checked_mul(x))
}

bench_sums! {
bench_filter_map_chain_sum,
bench_filter_map_chain_ref_sum,
bench_filter_map_chain_ext_sum,
(0i64..1000000).chain(0..1000000).filter_map(|x| x.checked_mul(x))
}

bench_sums! {
bench_fuse_sum,
bench_fuse_ref_sum,
bench_fuse_ext_sum,
(0i64..1000000).fuse()
}

bench_sums! {
bench_fuse_chain_sum,
bench_fuse_chain_ref_sum,
bench_fuse_chain_ext_sum,
(0i64..1000000).chain(0..1000000).fuse()
}

bench_sums! {
bench_inspect_sum,
bench_inspect_ref_sum,
bench_inspect_ext_sum,
(0i64..1000000).inspect(|_| {})
}

bench_sums! {
bench_inspect_chain_sum,
bench_inspect_chain_ref_sum,
bench_inspect_chain_ext_sum,
(0i64..1000000).chain(0..1000000).inspect(|_| {})
}

bench_sums! {
bench_peekable_sum,
bench_peekable_ref_sum,
bench_peekable_ext_sum,
(0i64..1000000).peekable()
}

bench_sums! {
bench_peekable_chain_sum,
bench_peekable_chain_ref_sum,
bench_peekable_chain_ext_sum,
(0i64..1000000).chain(0..1000000).peekable()
}

bench_sums! {
bench_skip_sum,
bench_skip_ref_sum,
bench_skip_ext_sum,
(0i64..1000000).skip(1000)
}

bench_sums! {
bench_skip_chain_sum,
bench_skip_chain_ref_sum,
bench_skip_chain_ext_sum,
(0i64..1000000).chain(0..1000000).skip(1000)
}

bench_sums! {
bench_skip_while_sum,
bench_skip_while_ref_sum,
bench_skip_while_ext_sum,
(0i64..1000000).skip_while(|&x| x < 1000)
}

bench_sums! {
bench_skip_while_chain_sum,
bench_skip_while_chain_ref_sum,
bench_skip_while_chain_ext_sum,
(0i64..1000000).chain(0..1000000).skip_while(|&x| x < 1000)
}

bench_sums! {
bench_take_while_chain_sum,
bench_take_while_chain_ref_sum,
bench_take_while_chain_ext_sum,
(0i64..1000000).chain(1000000..).take_while(|&x| x < 1111111)
}

bench_sums! {
bench_cycle_take_sum,
bench_cycle_take_ref_sum,
bench_cycle_take_ext_sum,
(0..10000).cycle().take(1000000)
}

bench_sums! {
bench_cycle_skip_take_sum,
bench_cycle_skip_take_ref_sum,
bench_cycle_skip_take_ext_sum,
(0..100000).cycle().skip(1000000).take(1000000)
}

bench_sums! {
bench_cycle_take_skip_sum,
bench_cycle_take_skip_ref_sum,
bench_cycle_take_skip_ext_sum,
(0..100000).cycle().take(1000000).skip(100000)
}

bench_sums! {
bench_skip_cycle_skip_zip_add_sum,
bench_skip_cycle_skip_zip_add_ref_sum,
bench_skip_cycle_skip_zip_add_ext_sum,
(0..100000).skip(100).cycle().skip(100)
.zip((0..100000).cycle().skip(10))
.map(|(a,b)| a+b)
.skip(100000)
.take(1000000)
}

bench_sums! {
bench_slice_chain_sum,
bench_slice_chain_ref_sum,
bench_slice_chain_ext_sum,
(&[0; 512]).iter().chain((&[1; 512]).iter())
}

// Checks whether Skip<Zip<A,B>> is as fast as Zip<Skip<A>, Skip<B>>, from
// https://users.rust-lang.org/t/performance-difference-between-iterator-zip-and-skip-order/15743
#[bench]
Expand Down
119 changes: 117 additions & 2 deletions library/core/src/iter/adapters/chain.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::iter::{DoubleEndedIterator, FusedIterator, Iterator, TrustedLen};
use crate::num::NonZeroUsize;
use crate::{mem, ptr};
use crate::ops::Try;

/// An iterator that links two iterators together, in a chain.
Expand Down Expand Up @@ -48,7 +49,7 @@ where

#[inline]
fn next(&mut self) -> Option<A::Item> {
and_then_or_clear(&mut self.a, Iterator::next).or_else(|| self.b.as_mut()?.next())
SpecChain::next(self)
}

#[inline]
Expand Down Expand Up @@ -178,7 +179,7 @@ where
{
#[inline]
fn next_back(&mut self) -> Option<A::Item> {
and_then_or_clear(&mut self.b, |b| b.next_back()).or_else(|| self.a.as_mut()?.next_back())
SpecChainBack::next_back(self)
}

#[inline]
Expand Down Expand Up @@ -303,3 +304,117 @@ fn and_then_or_clear<T, U>(opt: &mut Option<T>, f: impl FnOnce(&mut T) -> Option
}
x
}

/// Marks the two generic parameters of Chain as sufficiently equal that their values can be swapped
///
/// # Safety
///
/// This would be trivially safe if both types were identical, including lifetimes.
/// However we can't specify bounds like that and it would be overly restrictive since it's not
/// uncommon for borrowing iterators to have slightly different lifetimes.
///
/// We can relax this by only requiring that the base struct type is the same while ignoring
/// lifetime parameters as long as
/// * the actual runtime lifespan of the values is capped by the shorter of the two lifetimes
/// * all invoked trait methods (and drop code) monomorphize down to the same code
#[rustc_unsafe_specialization_marker]
unsafe trait SymmetricalModuloLifetimes {}

/// Safety:
/// * <A, A> ensures that the basic type is the same
/// * actual lifespan of the values is capped by the combined lifetime of Chain's fields as long as
/// there is no way to destructure Chain into. I.e. Chain must not implement `SourceIter`,
/// `into_parts(self)` or similar methods.
/// * we rely on the language currently having no mechanism that would allow lifetime-dependent
/// code paths. Specialization forbids `where T: 'static` and similar bounds (modulo the exposed
/// `#[rustc_unsafe_specialization_marker]` traits).
/// And any trait depending on `Any` would have to be 'static in *both* arms to make a useful Chain.
/// This is only true as long as *all* impls on `Chain` have the same bounds for A and B,
/// which currently is the case.
unsafe impl<A> SymmetricalModuloLifetimes for Chain<A, A> {}

trait SpecChain: Iterator {
fn next(&mut self) -> Option<Self::Item>;
}

trait SpecChainBack: DoubleEndedIterator {
fn next_back(&mut self) -> Option<Self::Item>;
}

impl<A, B> SpecChain for Chain<A, B>
where
A: Iterator,
B: Iterator<Item = A::Item>,
{
#[inline]
default fn next(&mut self) -> Option<A::Item> {
and_then_or_clear(&mut self.a, Iterator::next).or_else(|| self.b.as_mut()?.next())
}
}

impl<A, B> SpecChainBack for Chain<A, B>
where
A: DoubleEndedIterator,
B: DoubleEndedIterator<Item = A::Item>,
{
#[inline]
default fn next_back(&mut self) -> Option<Self::Item> {
and_then_or_clear(&mut self.b, |b| b.next_back()).or_else(|| self.a.as_mut()?.next_back())
}
}

impl<A, B> SpecChain for Chain<A, B>
where
A: Iterator + FusedIterator,
B: Iterator<Item = A::Item> + FusedIterator,
Self: SymmetricalModuloLifetimes,
{
#[inline]
fn next(&mut self) -> Option<A::Item> {
let mut result = self.a.as_mut().and_then( Iterator::next);
if result.is_none() {
if mem::needs_drop::<A>() {
// swap iters to avoid running drop code inside the loop.
// SAFETY: SymmetricalModuloLifetimes guarantees that A and B are safe to swap.
unsafe { mem::swap(&mut self.a, &mut *(&mut self.b as *mut _ as *mut Option<A>)) };
} else {
// SAFETY: SymmetricalModuloLifetimes guarantees that A and B are safe to swap.
// And they dont need drop, so we can overwrite the values directly.
unsafe {
ptr::write(&mut self.a, ptr::from_ref(&self.b).cast::<Option<A>>().read());
ptr::write(&mut self.b, None);
}
}
result = self.a.as_mut().and_then(Iterator::next);
}
result
}
}

impl<A, B> SpecChainBack for Chain<A, B>
where
A: DoubleEndedIterator + FusedIterator,
B: DoubleEndedIterator<Item = A::Item> + FusedIterator,
Self: SymmetricalModuloLifetimes,
{
#[inline]
fn next_back(&mut self) -> Option<Self::Item> {
let mut result = self.b.as_mut().and_then( DoubleEndedIterator::next_back);
if result.is_none() {
if mem::needs_drop::<A>() {
// swap iters to avoid running drop code inside the loop.
// SAFETY: SymmetricalModuloLifetimes guarantees that A and B are safe to swap.
unsafe { mem::swap(&mut self.a, &mut *(&mut self.b as *mut _ as *mut Option<A>)) };
} else {
// SAFETY: SymmetricalModuloLifetimes guarantees that A and B are safe to swap.
// And they dont need drop, so we can overwrite the values directly.
unsafe {
ptr::write(&mut self.b, ptr::from_ref(&self.a).cast::<Option<B>>().read());
ptr::write(&mut self.a, None);
}
}
result = self.b.as_mut().and_then(DoubleEndedIterator::next_back);
}
result
}
}