Skip to content

Commit 8ed5720

Browse files
authored
Merge pull request #495 from jturner314/merge-more-axes
Support merging axes in cases with lengths <= 1
2 parents 245cfbb + a2e166c commit 8ed5720

File tree

3 files changed

+135
-11
lines changed

3 files changed

+135
-11
lines changed

src/dimension/mod.rs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -287,17 +287,27 @@ pub fn do_slice(
287287
pub fn merge_axes<D>(dim: &mut D, strides: &mut D, take: Axis, into: Axis) -> bool
288288
where D: Dimension,
289289
{
290-
let il = dim.axis(into);
291-
let is = strides.axis(into) as Ixs;
292-
let tl = dim.axis(take);
293-
let ts = strides.axis(take) as Ixs;
294-
if il as Ixs * is != ts {
295-
return false;
290+
let into_len = dim.axis(into);
291+
let into_stride = strides.axis(into) as isize;
292+
let take_len = dim.axis(take);
293+
let take_stride = strides.axis(take) as isize;
294+
let merged_len = into_len * take_len;
295+
if take_len <= 1 {
296+
dim.set_axis(into, merged_len);
297+
dim.set_axis(take, if merged_len == 0 { 0 } else { 1 });
298+
true
299+
} else if into_len <= 1 {
300+
strides.set_axis(into, take_stride as usize);
301+
dim.set_axis(into, merged_len);
302+
dim.set_axis(take, if merged_len == 0 { 0 } else { 1 });
303+
true
304+
} else if take_stride == into_len as isize * into_stride {
305+
dim.set_axis(into, merged_len);
306+
dim.set_axis(take, 1);
307+
true
308+
} else {
309+
false
296310
}
297-
// merge them
298-
dim.set_axis(into, il * tl);
299-
dim.set_axis(take, 1);
300-
true
301311
}
302312

303313

src/impl_methods.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1478,12 +1478,35 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
14781478

14791479
/// If possible, merge in the axis `take` to `into`.
14801480
///
1481+
/// Returns `true` iff the axes are now merged.
1482+
///
1483+
/// This method merges the axes if movement along the two original axes
1484+
/// (moving fastest along the `into` axis) can be equivalently represented
1485+
/// as movement along one (merged) axis. Merging the axes preserves this
1486+
/// order in the merged axis. If `take` and `into` are the same axis, then
1487+
/// the axis is "merged" if its length is ≤ 1.
1488+
///
1489+
/// If the return value is `true`, then the following hold:
1490+
///
1491+
/// * The new length of the `into` axis is the product of the original
1492+
/// lengths of the two axes.
1493+
///
1494+
/// * The new length of the `take` axis is 0 if the product of the original
1495+
/// lengths of the two axes is 0, and 1 otherwise.
1496+
///
1497+
/// If the return value is `false`, then merging is not possible, and the
1498+
/// original shape and strides have been preserved.
1499+
///
1500+
/// Note that the ordering constraint means that if it's possible to merge
1501+
/// `take` into `into`, it's usually not possible to merge `into` into
1502+
/// `take`, and vice versa.
1503+
///
14811504
/// ```
14821505
/// use ndarray::Array3;
14831506
/// use ndarray::Axis;
14841507
///
14851508
/// let mut a = Array3::<f64>::zeros((2, 3, 4));
1486-
/// a.merge_axes(Axis(1), Axis(2));
1509+
/// assert!(a.merge_axes(Axis(1), Axis(2)));
14871510
/// assert_eq!(a.shape(), &[2, 1, 12]);
14881511
/// ```
14891512
///

tests/array.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,97 @@ fn diag()
573573
assert_eq!(d.dim(), 1);
574574
}
575575

576+
/// Check that the merged shape is correct.
577+
///
578+
/// Note that this does not check the strides in the "merged" case!
579+
#[test]
580+
fn merge_axes() {
581+
macro_rules! assert_merged {
582+
($arr:expr, $slice:expr, $take:expr, $into:expr) => {
583+
let mut v = $arr.slice($slice);
584+
let merged_len = v.len_of(Axis($take)) * v.len_of(Axis($into));
585+
assert!(v.merge_axes(Axis($take), Axis($into)));
586+
assert_eq!(v.len_of(Axis($take)), if merged_len == 0 { 0 } else { 1 });
587+
assert_eq!(v.len_of(Axis($into)), merged_len);
588+
}
589+
}
590+
macro_rules! assert_not_merged {
591+
($arr:expr, $slice:expr, $take:expr, $into:expr) => {
592+
let mut v = $arr.slice($slice);
593+
let old_dim = v.raw_dim();
594+
let old_strides = v.strides().to_owned();
595+
assert!(!v.merge_axes(Axis($take), Axis($into)));
596+
assert_eq!(v.raw_dim(), old_dim);
597+
assert_eq!(v.strides(), &old_strides[..]);
598+
}
599+
}
600+
601+
let a = Array4::<u8>::zeros((3, 4, 5, 4));
602+
603+
assert_not_merged!(a, s![.., .., .., ..], 0, 0);
604+
assert_merged!(a, s![.., .., .., ..], 0, 1);
605+
assert_not_merged!(a, s![.., .., .., ..], 0, 2);
606+
assert_not_merged!(a, s![.., .., .., ..], 0, 3);
607+
assert_not_merged!(a, s![.., .., .., ..], 1, 0);
608+
assert_not_merged!(a, s![.., .., .., ..], 1, 1);
609+
assert_merged!(a, s![.., .., .., ..], 1, 2);
610+
assert_not_merged!(a, s![.., .., .., ..], 1, 3);
611+
assert_not_merged!(a, s![.., .., .., ..], 2, 1);
612+
assert_not_merged!(a, s![.., .., .., ..], 2, 2);
613+
assert_merged!(a, s![.., .., .., ..], 2, 3);
614+
assert_not_merged!(a, s![.., .., .., ..], 3, 0);
615+
assert_not_merged!(a, s![.., .., .., ..], 3, 1);
616+
assert_not_merged!(a, s![.., .., .., ..], 3, 2);
617+
assert_not_merged!(a, s![.., .., .., ..], 3, 3);
618+
619+
assert_merged!(a, s![.., .., .., ..;2], 0, 1);
620+
assert_not_merged!(a, s![.., .., .., ..;2], 1, 0);
621+
assert_merged!(a, s![.., .., .., ..;2], 1, 2);
622+
assert_not_merged!(a, s![.., .., .., ..;2], 2, 1);
623+
assert_merged!(a, s![.., .., .., ..;2], 2, 3);
624+
assert_not_merged!(a, s![.., .., .., ..;2], 3, 2);
625+
626+
assert_merged!(a, s![.., .., .., ..3], 0, 1);
627+
assert_not_merged!(a, s![.., .., .., ..3], 1, 0);
628+
assert_merged!(a, s![.., .., .., ..3], 1, 2);
629+
assert_not_merged!(a, s![.., .., .., ..3], 2, 1);
630+
assert_not_merged!(a, s![.., .., .., ..3], 2, 3);
631+
632+
assert_merged!(a, s![.., .., ..;2, ..], 0, 1);
633+
assert_not_merged!(a, s![.., .., ..;2, ..], 1, 0);
634+
assert_not_merged!(a, s![.., .., ..;2, ..], 1, 2);
635+
assert_not_merged!(a, s![.., .., ..;2, ..], 2, 3);
636+
637+
assert_merged!(a, s![.., ..;2, .., ..], 0, 1);
638+
assert_not_merged!(a, s![.., ..;2, .., ..], 1, 0);
639+
assert_not_merged!(a, s![.., ..;2, .., ..], 1, 2);
640+
assert_merged!(a, s![.., ..;2, .., ..], 2, 3);
641+
assert_not_merged!(a, s![.., ..;2, .., ..], 3, 2);
642+
643+
let a = Array4::<u8>::zeros((3, 1, 5, 1).f());
644+
assert_merged!(a, s![.., .., ..;2, ..], 0, 1);
645+
assert_merged!(a, s![.., .., ..;2, ..], 0, 3);
646+
assert_merged!(a, s![.., .., ..;2, ..], 1, 0);
647+
assert_merged!(a, s![.., .., ..;2, ..], 1, 1);
648+
assert_merged!(a, s![.., .., ..;2, ..], 1, 2);
649+
assert_merged!(a, s![.., .., ..;2, ..], 1, 3);
650+
assert_merged!(a, s![.., .., ..;2, ..], 2, 1);
651+
assert_merged!(a, s![.., .., ..;2, ..], 2, 3);
652+
assert_merged!(a, s![.., .., ..;2, ..], 3, 0);
653+
assert_merged!(a, s![.., .., ..;2, ..], 3, 1);
654+
assert_merged!(a, s![.., .., ..;2, ..], 3, 2);
655+
assert_merged!(a, s![.., .., ..;2, ..], 3, 3);
656+
657+
let a = Array4::<u8>::zeros((3, 0, 5, 1));
658+
assert_merged!(a, s![.., .., ..;2, ..], 0, 1);
659+
assert_merged!(a, s![.., .., ..;2, ..], 1, 1);
660+
assert_merged!(a, s![.., .., ..;2, ..], 2, 1);
661+
assert_merged!(a, s![.., .., ..;2, ..], 3, 1);
662+
assert_merged!(a, s![.., .., ..;2, ..], 1, 0);
663+
assert_merged!(a, s![.., .., ..;2, ..], 1, 2);
664+
assert_merged!(a, s![.., .., ..;2, ..], 1, 3);
665+
}
666+
576667
#[test]
577668
fn swapaxes()
578669
{

0 commit comments

Comments
 (0)