Skip to content

Commit 1d51f70

Browse files
Merge pull request rust-ndarray#3 from jturner314/pairwise-summation
Improve pairwise summation
2 parents bbc4a75 + 82453df commit 1d51f70

File tree

3 files changed

+77
-27
lines changed

3 files changed

+77
-27
lines changed

benches/numeric.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
#![feature(test)]
33
extern crate test;
4-
use test::Bencher;
4+
use test::{black_box, Bencher};
55

66
extern crate ndarray;
77
use ndarray::prelude::*;
@@ -65,6 +65,38 @@ fn contiguous_sum_1e2(bench: &mut Bencher)
6565
});
6666
}
6767

68+
/// Benchmarks `sum()` over a three-dimensional `n x n x n` array (`n = 100`).
///
/// The array is built by reshaping a `linspace` of `n * n * n` values spanning
/// `-1e6..=1e6`; no slicing is applied before summing.
#[bench]
69+
fn contiguous_sum_ix3_1e2(bench: &mut Bencher)
70+
{
71+
let n = 1e2 as usize;
72+
let a = Array::linspace(-1e6, 1e6, n * n * n)
73+
.into_shape([n, n, n])
74+
.unwrap();
75+
// `black_box` keeps the input opaque to the optimizer so the sum is actually computed.
bench.iter(|| black_box(&a).sum());
76+
}
77+
78+
/// Benchmarks `sum()` over a three-dimensional view that is strided along its
/// last (innermost) axis.
///
/// The backing array has shape `[n, n, 2 * n]` with `n = 100`; the view keeps
/// every second element along the last axis, so it holds `n * n * n` elements.
#[bench]
79+
fn inner_discontiguous_sum_ix3_1e2(bench: &mut Bencher)
80+
{
81+
let n = 1e2 as usize;
82+
let a = Array::linspace(-1e6, 1e6, n * n * 2*n)
83+
.into_shape([n, n, 2*n])
84+
.unwrap();
85+
// Step of 2 on the last axis: the elements of each innermost lane are no longer adjacent.
let v = a.slice(s![.., .., ..;2]);
86+
// `black_box` keeps the view opaque to the optimizer so the sum is actually computed.
bench.iter(|| black_box(&v).sum());
87+
}
88+
89+
/// Benchmarks `sum()` over a three-dimensional view that is strided along its
/// middle axis.
///
/// The backing array has shape `[n, 2 * n, n]` with `n = 100`; the view keeps
/// every second index along the middle axis, so it holds `n * n * n` elements.
#[bench]
90+
fn middle_discontiguous_sum_ix3_1e2(bench: &mut Bencher)
91+
{
92+
let n = 1e2 as usize;
93+
let a = Array::linspace(-1e6, 1e6, n * 2*n * n)
94+
.into_shape([n, 2*n, n])
95+
.unwrap();
96+
// Step of 2 on the middle axis only; the last axis is taken in full.
let v = a.slice(s![.., ..;2, ..]);
97+
// `black_box` keeps the view opaque to the optimizer so the sum is actually computed.
bench.iter(|| black_box(&v).sum());
98+
}
99+
68100
#[bench]
69101
fn sum_by_row_1e4(bench: &mut Bencher)
70102
{
@@ -88,3 +120,15 @@ fn sum_by_col_1e4(bench: &mut Bencher)
88120
a.sum_axis(Axis(1))
89121
});
90122
}
123+
124+
/// Benchmarks `sum_axis(Axis(1))`, i.e. summing along the middle axis, on an
/// `n x n x n` array with `n = 100`.
///
/// The array is built by reshaping a `linspace` of `n * n * n` values spanning
/// `-1e6..=1e6`.
#[bench]
125+
fn sum_by_middle_1e2(bench: &mut Bencher)
126+
{
127+
let n = 1e2 as usize;
128+
let a = Array::linspace(-1e6, 1e6, n * n * n)
129+
.into_shape([n, n, n])
130+
.unwrap();
131+
// The closure's return value (the reduced array) is handed back to `Bencher::iter`.
bench.iter(|| {
132+
a.sum_axis(Axis(1))
133+
});
134+
}

src/numeric/impl_numeric.rs

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
use std::ops::{Add, Div, Mul};
1010
use num_traits::{self, Zero, Float, FromPrimitive};
11-
use itertools::free::enumerate;
1211

1312
use crate::imp_prelude::*;
1413
use crate::numeric_util;
@@ -33,17 +32,10 @@ impl<A, S, D> ArrayBase<S, D>
3332
where A: Clone + Add<Output=A> + num_traits::Zero,
3433
{
3534
if let Some(slc) = self.as_slice_memory_order() {
36-
return numeric_util::pairwise_sum(&slc)
37-
}
38-
let mut sum = A::zero();
39-
for row in self.inner_rows() {
40-
if let Some(slc) = row.as_slice() {
41-
sum = sum + numeric_util::pairwise_sum(&slc);
42-
} else {
43-
sum = sum + numeric_util::iterator_pairwise_sum(row.iter());
44-
}
35+
numeric_util::pairwise_sum(&slc)
36+
} else {
37+
numeric_util::iterator_pairwise_sum(self.iter())
4538
}
46-
sum
4739
}
4840

4941
/// Return the sum of all elements in the array.
@@ -104,16 +96,14 @@ impl<A, S, D> ArrayBase<S, D>
10496
D: RemoveAxis,
10597
{
10698
let n = self.len_of(axis);
107-
let stride = self.strides()[axis.index()];
108-
if self.ndim() == 2 && stride == 1 {
99+
if self.stride_of(axis) == 1 {
109100
// contiguous along the axis we are summing
110101
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
111-
let ax = axis.index();
112-
for (i, elt) in enumerate(&mut res) {
113-
*elt = self.index_axis(Axis(1 - ax), i).sum();
114-
}
102+
Zip::from(&mut res)
103+
.and(self.lanes(axis))
104+
.apply(|sum, lane| *sum = lane.sum());
115105
res
116-
} else if self.len_of(axis) <= numeric_util::NAIVE_SUM_THRESHOLD {
106+
} else if n <= numeric_util::NAIVE_SUM_THRESHOLD {
117107
self.fold_axis(axis, A::zero(), |acc, x| acc.clone() + x.clone())
118108
} else {
119109
let (v1, v2) = self.view().split_at(axis, n / 2);

src/numeric_util.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,18 @@ where
5656
I: Iterator<Item=&'a A>,
5757
A: Clone + Add<Output=A> + Zero,
5858
{
59-
let mut partial_sums = vec![];
60-
let mut partial_sum = A::zero();
61-
for (i, x) in iter.enumerate() {
62-
partial_sum = partial_sum + x.clone();
63-
if i % NAIVE_SUM_THRESHOLD == NAIVE_SUM_THRESHOLD - 1 {
59+
let (len, _) = iter.size_hint();
60+
let cap = len.saturating_sub(1) / NAIVE_SUM_THRESHOLD + 1; // ceiling of division
61+
let mut partial_sums = Vec::with_capacity(cap);
62+
let (_, last_sum) = iter.fold((0, A::zero()), |(count, partial_sum), x| {
63+
if count < NAIVE_SUM_THRESHOLD {
64+
(count + 1, partial_sum + x.clone())
65+
} else {
6466
partial_sums.push(partial_sum);
65-
partial_sum = A::zero();
67+
(1, x.clone())
6668
}
67-
}
68-
partial_sums.push(partial_sum);
69+
});
70+
partial_sums.push(last_sum);
6971

7072
pure_pairwise_sum(&partial_sums)
7173
}
@@ -205,3 +207,17 @@ pub fn unrolled_eq<A>(xs: &[A], ys: &[A]) -> bool
205207

206208
true
207209
}
210+
211+
// Property-based tests for the summation helpers in this module.
#[cfg(test)]
212+
mod tests {
213+
use quickcheck::quickcheck;
214+
use std::num::Wrapping;
215+
use super::iterator_pairwise_sum;
216+
217+
quickcheck! {
218+
// For arbitrary input vectors, `iterator_pairwise_sum` must agree with the
// standard library's `Iterator::sum`.
fn iterator_pairwise_sum_is_correct(xs: Vec<i32>) -> bool {
219+
// `Wrapping` gives wrap-around addition, so overflow on arbitrary `i32`
// inputs cannot panic, and the result does not depend on summation order.
let xs: Vec<_> = xs.into_iter().map(|x| Wrapping(x)).collect();
220+
iterator_pairwise_sum(xs.iter()) == xs.iter().sum()
221+
}
222+
}
223+
}

0 commit comments

Comments
 (0)