Skip to content

Commit 35065cb

Browse files
authored
Merge pull request #684 from jturner314/zip-fold
Add .fold() method to Zip
2 parents 52bc1f9 + 27c010d commit 35065cb

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

src/zip/mod.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -755,8 +755,44 @@ macro_rules! map_impl {
755755
/// Apply a fold function to all elements of the input arrays,
756756
/// visiting elements in lock step.
757757
///
758-
/// The fold continues while the return value is a
759-
/// `FoldWhile::Continue`.
758+
/// # Example
759+
///
760+
/// The expression `tr(AᵀB)` can be more efficiently computed as
761+
/// the equivalent expression `∑ᵢⱼ(A∘B)ᵢⱼ` (i.e. the sum of the
762+
/// elements of the entry-wise product). It would be possible to
763+
/// evaluate this expression by first computing the entry-wise
764+
/// product, `A∘B`, and then computing the elementwise sum of that
765+
/// product, but it's possible to do this in a single loop (and
766+
/// avoid an extra heap allocation if `A` and `B` can't be
767+
/// consumed) by using `Zip`:
768+
///
769+
/// ```
770+
/// use ndarray::{array, Zip};
771+
///
772+
/// let a = array![[1, 5], [3, 7]];
773+
/// let b = array![[2, 4], [8, 6]];
774+
///
775+
/// // Without using `Zip`. This involves two loops and an extra
776+
/// // heap allocation for the result of `&a * &b`.
777+
/// let sum_prod_nonzip = (&a * &b).sum();
778+
/// // Using `Zip`. This is a single loop without any heap allocations.
779+
/// let sum_prod_zip = Zip::from(&a).and(&b).fold(0, |acc, a, b| acc + a * b);
780+
///
781+
/// assert_eq!(sum_prod_nonzip, sum_prod_zip);
782+
/// ```
783+
pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
784+
where
785+
F: FnMut(Acc, $($p::Item),*) -> Acc,
786+
{
787+
self.apply_core(acc, move |acc, args| {
788+
let ($($p,)*) = args;
789+
FoldWhile::Continue(function(acc, $($p),*))
790+
}).into_inner()
791+
}
792+
793+
/// Apply a fold function to the input arrays while the return
794+
/// value is `FoldWhile::Continue`, visiting elements in lock step.
795+
///
760796
pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
761797
-> FoldWhile<Acc>
762798
where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>

0 commit comments

Comments
 (0)