Skip to content

Commit 2d6bcde

Browse files
committed
FEAT: Add method squeeze_into
This method can squeeze into a particular dimensionality. Squeezing means removing axes of length 1. When squeezing to a particular dimensionality, we may have to still pad out the shape with extra 1-shape axes to fill the dimensionality.
1 parent 12b9525 commit 2d6bcde

File tree

1 file changed

+134
-9
lines changed

1 file changed

+134
-9
lines changed

src/dimension/mod.rs

+134-9
Original file line numberDiff line numberDiff line change
@@ -729,37 +729,77 @@ where
729729
}
730730

731731
/// Remove axes with length one, except never removing the last axis.
732+
///
733+
/// This function is a no-op for const dim.
732734
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
733735
where
734736
D: Dimension,
735737
{
736738
if let Some(_) = D::NDIM {
737739
return;
738740
}
741+
742+
// infallible for dyn dim
743+
let (d, s) = squeeze_into(dim, strides).unwrap();
744+
*dim = d;
745+
*strides = s;
746+
}
747+
748+
/// Remove axes with length one, except never removing the last axis.
749+
///
750+
/// Return an error if there are more non-unitary dimensions than can be stored
751+
/// in `E`. Infallible for dyn dim.
752+
///
753+
/// Squeeze does not shrink dyn dim down to smaller than 1D, but if the input is
754+
/// dynamic 0D, the output can be too.
755+
///
756+
/// For const dim, this may instead pad the dimensionality with ones if it needs
757+
/// to grow to fill the target dimensionality; the dimension is padded in the
758+
/// start.
759+
pub(crate) fn squeeze_into<D, E>(dim: &D, strides: &D) -> Result<(E, E), ShapeError>
760+
where
761+
D: Dimension,
762+
E: Dimension,
763+
{
739764
debug_assert_eq!(dim.ndim(), strides.ndim());
740765

741766
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
742767
let mut ndim_new = 0;
743768
for &d in dim.slice() {
744769
if d != 1 { ndim_new += 1; }
745770
}
746-
ndim_new = Ord::max(1, ndim_new);
747-
let mut new_dim = D::zeros(ndim_new);
748-
let mut new_strides = D::zeros(ndim_new);
771+
let mut fill_ones = 0;
772+
if let Some(e_ndim) = E::NDIM {
773+
if e_ndim < ndim_new {
774+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
775+
}
776+
fill_ones = e_ndim - ndim_new;
777+
ndim_new = e_ndim;
778+
} else {
779+
// dynamic-dimensional
780+
// use minimum one dimension unless input has less than one dim
781+
if dim.ndim() > 0 && ndim_new == 0 {
782+
ndim_new = 1;
783+
fill_ones = 1;
784+
}
785+
}
786+
787+
let mut new_dim = E::zeros(ndim_new);
788+
let mut new_strides = E::zeros(ndim_new);
749789
let mut i = 0;
790+
while i < fill_ones {
791+
new_dim[i] = 1;
792+
new_strides[i] = 1;
793+
i += 1;
794+
}
750795
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
751796
if d != 1 {
752797
new_dim[i] = d;
753798
new_strides[i] = s;
754799
i += 1;
755800
}
756801
}
757-
if i == 0 {
758-
new_dim[i] = 1;
759-
new_strides[i] = 1;
760-
}
761-
*dim = new_dim;
762-
*strides = new_strides;
802+
Ok((new_dim, new_strides))
763803
}
764804

765805

@@ -1148,6 +1188,91 @@ mod test {
11481188
assert_eq!(s, sans);
11491189
}
11501190

1191+
#[test]
1192+
#[cfg(feature = "std")]
1193+
fn test_squeeze_into() {
1194+
use super::squeeze_into;
1195+
1196+
let dyndim = Dim::<&[usize]>;
1197+
1198+
// squeeze to ixdyn
1199+
let d = dyndim(&[1, 2, 1, 1, 3, 1]);
1200+
let s = dyndim(&[!0, !0, !0, 9, 10, !0]);
1201+
let dans = dyndim(&[2, 3]);
1202+
let sans = dyndim(&[!0, 10]);
1203+
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
1204+
assert_eq!(d2, dans);
1205+
assert_eq!(s2, sans);
1206+
1207+
// squeeze to ixdyn does not go below 1D
1208+
let d = dyndim(&[1, 1]);
1209+
let s = dyndim(&[3, 4]);
1210+
let dans = dyndim(&[1]);
1211+
let sans = dyndim(&[1]);
1212+
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
1213+
assert_eq!(d2, dans);
1214+
assert_eq!(s2, sans);
1215+
1216+
let d = Dim([1, 1]);
1217+
let s = Dim([3, 4]);
1218+
let dans = Dim([1]);
1219+
let sans = Dim([1]);
1220+
let (d2, s2) = squeeze_into::<_, Ix1>(&d, &s).unwrap();
1221+
assert_eq!(d2, dans);
1222+
assert_eq!(s2, sans);
1223+
1224+
// squeeze to zero-dim
1225+
let (d2, s2) = squeeze_into::<_, Ix0>(&d, &s).unwrap();
1226+
assert_eq!(d2, Ix0());
1227+
assert_eq!(s2, Ix0());
1228+
1229+
let d = Dim([0, 1, 3, 4]);
1230+
let s = Dim([2, 3, 4, 5]);
1231+
let dans = Dim([0, 3, 4]);
1232+
let sans = Dim([2, 4, 5]);
1233+
let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap();
1234+
assert_eq!(d2, dans);
1235+
assert_eq!(s2, sans);
1236+
1237+
// Pad with ones
1238+
let d = Dim([0, 1, 3, 1]);
1239+
let s = Dim([2, 3, 4, 5]);
1240+
let dans = Dim([1, 0, 3]);
1241+
let sans = Dim([1, 2, 4]);
1242+
let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap();
1243+
assert_eq!(d2, dans);
1244+
assert_eq!(s2, sans);
1245+
1246+
// Try something that doesn't fit
1247+
let d = Dim([0, 1, 3, 1]);
1248+
let s = Dim([2, 3, 4, 5]);
1249+
let res = squeeze_into::<_, Ix1>(&d, &s);
1250+
assert!(res.is_err());
1251+
let res = squeeze_into::<_, Ix0>(&d, &s);
1252+
assert!(res.is_err());
1253+
1254+
// Squeeze 0d to 0d
1255+
let d = Dim([]);
1256+
let s = Dim([]);
1257+
let res = squeeze_into::<_, Ix0>(&d, &s);
1258+
assert!(res.is_ok());
1259+
// grow 0d to 2d
1260+
let dans = Dim([1, 1]);
1261+
let sans = Dim([1, 1]);
1262+
let (d2, s2) = squeeze_into::<_, Ix2>(&d, &s).unwrap();
1263+
assert_eq!(d2, dans);
1264+
assert_eq!(s2, sans);
1265+
1266+
// Squeeze 0d to 0d dynamic
1267+
let d = dyndim(&[]);
1268+
let s = dyndim(&[]);
1269+
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
1270+
let dans = d;
1271+
let sans = s;
1272+
assert_eq!(d2, dans);
1273+
assert_eq!(s2, sans);
1274+
}
1275+
11511276
#[test]
11521277
fn test_merge_axes_from_the_back() {
11531278
let dyndim = Dim::<&[usize]>;

0 commit comments

Comments
 (0)