Skip to content

Commit abda078

Browse files
authored
Merge pull request #547 from rust-ndarray/iter-rfold
Implement Iterator::rfold where we can
2 parents 39ccd4a + ca3f987 commit abda078

File tree

3 files changed

+110
-3
lines changed

3 files changed

+110
-3
lines changed

benches/iter.rs

+22
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use rawpointer::PointerExt;
1010
extern crate ndarray;
1111
use ndarray::prelude::*;
1212
use ndarray::{Zip, FoldWhile};
13+
use ndarray::Slice;
1314

1415
#[bench]
1516
fn iter_sum_2d_regular(bench: &mut Bencher)
@@ -342,3 +343,24 @@ fn indexed_iter_3d_dyn(bench: &mut Bencher) {
342343
}
343344
})
344345
}
346+
347+
#[bench]
348+
fn iter_sum_1d_strided_fold(bench: &mut Bencher)
349+
{
350+
let mut a = Array::<u64, _>::ones(10240);
351+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
352+
bench.iter(|| {
353+
a.iter().fold(0, |acc, &x| acc + x)
354+
});
355+
}
356+
357+
#[bench]
358+
fn iter_sum_1d_strided_rfold(bench: &mut Bencher)
359+
{
360+
let mut a = Array::<u64, _>::ones(10240);
361+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
362+
bench.iter(|| {
363+
a.iter().rfold(0, |acc, &x| acc + x)
364+
});
365+
}
366+

src/iterators/mod.rs

+45
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,23 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1> {
148148

149149
unsafe { Some(self.ptr.offset(offset)) }
150150
}
151+
152+
fn rfold<Acc, G>(mut self, init: Acc, mut g: G) -> Acc
153+
where G: FnMut(Acc, *mut A) -> Acc,
154+
{
155+
let mut accum = init;
156+
if let Some(index) = self.index {
157+
let elem_index = index[0];
158+
unsafe {
159+
// self.dim[0] is the current length
160+
while self.dim[0] > elem_index {
161+
self.dim[0] -= 1;
162+
accum = g(accum, self.ptr.offset(Ix1::stride_offset(&self.dim, &self.strides)));
163+
}
164+
}
165+
}
166+
accum
167+
}
151168
}
152169

153170
clone_bounds!(
@@ -206,6 +223,14 @@ impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1> {
206223
fn next_back(&mut self) -> Option<&'a A> {
207224
self.inner.next_back().map(|p| unsafe { &*p })
208225
}
226+
227+
fn rfold<Acc, G>(self, init: Acc, mut g: G) -> Acc
228+
where G: FnMut(Acc, Self::Item) -> Acc,
229+
{
230+
unsafe {
231+
self.inner.rfold(init, move |acc, ptr| g(acc, &*ptr))
232+
}
233+
}
209234
}
210235

211236
impl<'a, A, D> ExactSizeIterator for ElementsBase<'a, A, D>
@@ -370,6 +395,12 @@ impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> {
370395
fn next_back(&mut self) -> Option<&'a A> {
371396
either_mut!(self.inner, iter => iter.next_back())
372397
}
398+
399+
fn rfold<Acc, G>(self, init: Acc, g: G) -> Acc
400+
where G: FnMut(Acc, Self::Item) -> Acc
401+
{
402+
either!(self.inner, iter => iter.rfold(init, g))
403+
}
373404
}
374405

375406
impl<'a, A, D> ExactSizeIterator for Iter<'a, A, D>
@@ -431,6 +462,12 @@ impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> {
431462
fn next_back(&mut self) -> Option<&'a mut A> {
432463
either_mut!(self.inner, iter => iter.next_back())
433464
}
465+
466+
fn rfold<Acc, G>(self, init: Acc, g: G) -> Acc
467+
where G: FnMut(Acc, Self::Item) -> Acc
468+
{
469+
either!(self.inner, iter => iter.rfold(init, g))
470+
}
434471
}
435472

436473
impl<'a, A, D> ExactSizeIterator for IterMut<'a, A, D>
@@ -466,6 +503,14 @@ impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> {
466503
fn next_back(&mut self) -> Option<&'a mut A> {
467504
self.inner.next_back().map(|p| unsafe { &mut *p })
468505
}
506+
507+
fn rfold<Acc, G>(self, init: Acc, mut g: G) -> Acc
508+
where G: FnMut(Acc, Self::Item) -> Acc
509+
{
510+
unsafe {
511+
self.inner.rfold(init, move |acc, ptr| g(acc, &mut *ptr))
512+
}
513+
}
469514
}
470515

471516
impl<'a, A, D> ExactSizeIterator for ElementsBaseMut<'a, A, D>

tests/iterators.rs

+43-3
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,17 @@
22
extern crate ndarray;
33
extern crate itertools;
44

5-
use ndarray::{Array0, Array2};
6-
use ndarray::RcArray;
5+
use ndarray::prelude::*;
76
use ndarray::Ix;
87
use ndarray::{
9-
ArrayBase,
108
Data,
119
Dimension,
1210
aview1,
1311
arr2,
1412
arr3,
1513
Axis,
1614
indices,
15+
Slice,
1716
};
1817

1918
use itertools::assert_equal;
@@ -498,3 +497,44 @@ fn test_fold() {
498497
a += 1;
499498
assert_eq!(a.iter().fold(0, |acc, &x| acc + x), 1);
500499
}
500+
501+
#[test]
502+
fn test_rfold() {
503+
{
504+
let mut a = Array1::<i32>::default(256);
505+
a += 1;
506+
let mut iter = a.iter();
507+
iter.next();
508+
assert_eq!(iter.rfold(0, |acc, &x| acc + x), a.sum() - 1);
509+
}
510+
511+
// Test strided arrays
512+
{
513+
let mut a = Array1::<i32>::default(256);
514+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
515+
a += 1;
516+
let mut iter = a.iter();
517+
iter.next();
518+
assert_eq!(iter.rfold(0, |acc, &x| acc + x), a.sum() - 1);
519+
}
520+
521+
{
522+
let mut a = Array1::<i32>::default(256);
523+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, -2));
524+
a += 1;
525+
let mut iter = a.iter();
526+
iter.next();
527+
assert_eq!(iter.rfold(0, |acc, &x| acc + x), a.sum() - 1);
528+
}
529+
530+
// Test order
531+
{
532+
let mut a = Array1::from_iter(0..20);
533+
a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2));
534+
let mut iter = a.iter();
535+
iter.next();
536+
let output = iter.rfold(Vec::new(),
537+
|mut acc, elt| { acc.push(*elt); acc });
538+
assert_eq!(Array1::from_vec(output), Array::from_iter((1..10).rev().map(|i| i * 2)));
539+
}
540+
}

0 commit comments

Comments
 (0)