Skip to content

Commit 4976d96

Browse files
jturner314LukeMathWalker
authored andcommitted
Implement .nth_back() for iterators (#686)
* Implement .nth_back() for iterators The `.nth_back()` method was added to the `DoubleEndedIterator` trait in Rust 1.37. Providing an implementation for `Baseiter` and forwarding it for `Iter/Mut` improves performance. * Split test_nth_back into multiple tests
1 parent c916203 commit 4976d96

File tree

4 files changed

+114
-1
lines changed

4 files changed

+114
-1
lines changed

.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ sudo: required
44
dist: trusty
55
matrix:
66
include:
7-
- rust: 1.32.0
7+
- rust: 1.37.0
88
env:
99
- FEATURES='test docs'
1010
- RUSTFLAGS='-D warnings'

benches/iter.rs

+21
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,27 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) {
7474
bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::<f32>());
7575
}
7676

77+
#[bench]
78+
fn iter_rev_step_by_contiguous(bench: &mut Bencher) {
79+
let a = Array::linspace(0., 1., 512);
80+
bench.iter(|| {
81+
a.iter().rev().step_by(2).for_each(|x| {
82+
black_box(x);
83+
})
84+
});
85+
}
86+
87+
#[bench]
88+
fn iter_rev_step_by_discontiguous(bench: &mut Bencher) {
89+
let mut a = Array::linspace(0., 1., 1024);
90+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
91+
bench.iter(|| {
92+
a.iter().rev().step_by(2).for_each(|x| {
93+
black_box(x);
94+
})
95+
});
96+
}
97+
7798
const ZIPSZ: usize = 10_000;
7899

79100
#[bench]

src/iterators/mod.rs

+24
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,22 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1> {
131131
unsafe { Some(self.ptr.offset(offset)) }
132132
}
133133

134+
fn nth_back(&mut self, n: usize) -> Option<*mut A> {
135+
let index = self.index?;
136+
let len = self.dim[0] - index[0];
137+
if n < len {
138+
self.dim[0] -= n + 1;
139+
let offset = <_>::stride_offset(&self.dim, &self.strides);
140+
if index == self.dim {
141+
self.index = None;
142+
}
143+
unsafe { Some(self.ptr.offset(offset)) }
144+
} else {
145+
self.index = None;
146+
None
147+
}
148+
}
149+
134150
fn rfold<Acc, G>(mut self, init: Acc, mut g: G) -> Acc
135151
where
136152
G: FnMut(Acc, *mut A) -> Acc,
@@ -437,6 +453,10 @@ impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> {
437453
either_mut!(self.inner, iter => iter.next_back())
438454
}
439455

456+
fn nth_back(&mut self, n: usize) -> Option<&'a A> {
457+
either_mut!(self.inner, iter => iter.nth_back(n))
458+
}
459+
440460
fn rfold<Acc, G>(self, init: Acc, g: G) -> Acc
441461
where
442462
G: FnMut(Acc, Self::Item) -> Acc,
@@ -561,6 +581,10 @@ impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> {
561581
either_mut!(self.inner, iter => iter.next_back())
562582
}
563583

584+
fn nth_back(&mut self, n: usize) -> Option<&'a mut A> {
585+
either_mut!(self.inner, iter => iter.nth_back(n))
586+
}
587+
564588
fn rfold<Acc, G>(self, init: Acc, g: G) -> Acc
565589
where
566590
G: FnMut(Acc, Self::Item) -> Acc,

tests/iterators.rs

+68
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,74 @@ fn test_fold() {
776776
assert_eq!(a.iter().fold(0, |acc, &x| acc + x), 1);
777777
}
778778

779+
#[test]
780+
fn nth_back_examples() {
781+
let mut a: Array1<i32> = (0..256).collect();
782+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
783+
assert_eq!(a.iter().nth_back(0), Some(&a[a.len() - 1]));
784+
assert_eq!(a.iter().nth_back(1), Some(&a[a.len() - 2]));
785+
assert_eq!(a.iter().nth_back(a.len() - 2), Some(&a[1]));
786+
assert_eq!(a.iter().nth_back(a.len() - 1), Some(&a[0]));
787+
assert_eq!(a.iter().nth_back(a.len()), None);
788+
assert_eq!(a.iter().nth_back(a.len() + 1), None);
789+
assert_eq!(a.iter().nth_back(a.len() + 2), None);
790+
}
791+
792+
#[test]
793+
fn nth_back_zero_n() {
794+
let mut a: Array1<i32> = (0..256).collect();
795+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
796+
let mut iter1 = a.iter();
797+
let mut iter2 = a.iter();
798+
for _ in 0..(a.len() + 1) {
799+
assert_eq!(iter1.nth_back(0), iter2.next_back());
800+
assert_eq!(iter1.len(), iter2.len());
801+
}
802+
}
803+
804+
#[test]
805+
fn nth_back_nonzero_n() {
806+
let mut a: Array1<i32> = (0..256).collect();
807+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
808+
let mut iter1 = a.iter();
809+
let mut iter2 = a.iter();
810+
for _ in 0..(a.len() / 3 + 1) {
811+
assert_eq!(iter1.nth_back(2), {
812+
iter2.next_back();
813+
iter2.next_back();
814+
iter2.next_back()
815+
});
816+
assert_eq!(iter1.len(), iter2.len());
817+
}
818+
}
819+
820+
#[test]
821+
fn nth_back_past_end() {
822+
let mut a: Array1<i32> = (0..256).collect();
823+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
824+
let mut iter = a.iter();
825+
assert_eq!(iter.nth_back(a.len()), None);
826+
assert_eq!(iter.next(), None);
827+
}
828+
829+
#[test]
830+
fn nth_back_partially_consumed() {
831+
let mut a: Array1<i32> = (0..256).collect();
832+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
833+
let mut iter = a.iter();
834+
iter.next();
835+
iter.next_back();
836+
assert_eq!(iter.len(), a.len() - 2);
837+
assert_eq!(iter.nth_back(1), Some(&a[a.len() - 3]));
838+
assert_eq!(iter.len(), a.len() - 4);
839+
assert_eq!(iter.nth_back(a.len() - 6), Some(&a[2]));
840+
assert_eq!(iter.len(), 1);
841+
assert_eq!(iter.next(), Some(&a[1]));
842+
assert_eq!(iter.len(), 0);
843+
assert_eq!(iter.next(), None);
844+
assert_eq!(iter.next_back(), None);
845+
}
846+
779847
#[test]
780848
fn test_rfold() {
781849
{

0 commit comments

Comments
 (0)