Skip to content

Commit 08068a8

Browse files
committed
FEAT: Add method squeeze_into
This method can squeeze into a particular dimensionality. Squeezing means removing axes of length 1. When squeezing to a particular dimensionality, we may have to still pad out the shape with extra 1-shape axes to fill the dimensionality.
1 parent 9cea3c7 commit 08068a8

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

src/dimension/mod.rs

+132
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,53 @@ where
762762
*strides = new_strides;
763763
}
764764

765+
/// Remove axes with length one, except never removing the last axis.
766+
pub(crate) fn squeeze_into<D, E>(dim: &D, strides: &D) -> Result<(E, E), ShapeError>
767+
where
768+
D: Dimension,
769+
E: Dimension,
770+
{
771+
debug_assert_eq!(dim.ndim(), strides.ndim());
772+
773+
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
774+
let mut ndim_new = 0;
775+
for &d in dim.slice() {
776+
if d != 1 { ndim_new += 1; }
777+
}
778+
let mut fill_ones = 0;
779+
if let Some(e_ndim) = E::NDIM {
780+
if e_ndim < ndim_new {
781+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
782+
}
783+
fill_ones = e_ndim - ndim_new;
784+
ndim_new = e_ndim;
785+
} else {
786+
// dynamic-dimensional
787+
// use minimum one dimension unless input has less than one dim
788+
if dim.ndim() > 0 && ndim_new == 0 {
789+
ndim_new = 1;
790+
fill_ones = 1;
791+
}
792+
}
793+
794+
let mut new_dim = E::zeros(ndim_new);
795+
let mut new_strides = E::zeros(ndim_new);
796+
let mut i = 0;
797+
while i < fill_ones {
798+
new_dim[i] = 1;
799+
new_strides[i] = 1;
800+
i += 1;
801+
}
802+
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
803+
if d != 1 {
804+
new_dim[i] = d;
805+
new_strides[i] = s;
806+
i += 1;
807+
}
808+
}
809+
Ok((new_dim, new_strides))
810+
}
811+
765812

766813
/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
767814
/// stride
@@ -1148,6 +1195,91 @@ mod test {
11481195
assert_eq!(s, sans);
11491196
}
11501197

1198+
#[test]
1199+
#[cfg(feature = "std")]
1200+
fn test_squeeze_into() {
1201+
use super::squeeze_into;
1202+
1203+
let dyndim = Dim::<&[usize]>;
1204+
1205+
// squeeze to ixdyn
1206+
let d = dyndim(&[1, 2, 1, 1, 3, 1]);
1207+
let s = dyndim(&[!0, !0, !0, 9, 10, !0]);
1208+
let dans = dyndim(&[2, 3]);
1209+
let sans = dyndim(&[!0, 10]);
1210+
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
1211+
assert_eq!(d2, dans);
1212+
assert_eq!(s2, sans);
1213+
1214+
// squeeze to ixdyn does not go below 1D
1215+
let d = dyndim(&[1, 1]);
1216+
let s = dyndim(&[3, 4]);
1217+
let dans = dyndim(&[1]);
1218+
let sans = dyndim(&[1]);
1219+
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
1220+
assert_eq!(d2, dans);
1221+
assert_eq!(s2, sans);
1222+
1223+
let d = Dim([1, 1]);
1224+
let s = Dim([3, 4]);
1225+
let dans = Dim([1]);
1226+
let sans = Dim([1]);
1227+
let (d2, s2) = squeeze_into::<_, Ix1>(&d, &s).unwrap();
1228+
assert_eq!(d2, dans);
1229+
assert_eq!(s2, sans);
1230+
1231+
// squeeze to zero-dim
1232+
let (d2, s2) = squeeze_into::<_, Ix0>(&d, &s).unwrap();
1233+
assert_eq!(d2, Ix0());
1234+
assert_eq!(s2, Ix0());
1235+
1236+
let d = Dim([0, 1, 3, 4]);
1237+
let s = Dim([2, 3, 4, 5]);
1238+
let dans = Dim([0, 3, 4]);
1239+
let sans = Dim([2, 4, 5]);
1240+
let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap();
1241+
assert_eq!(d2, dans);
1242+
assert_eq!(s2, sans);
1243+
1244+
// Pad with ones
1245+
let d = Dim([0, 1, 3, 1]);
1246+
let s = Dim([2, 3, 4, 5]);
1247+
let dans = Dim([1, 0, 3]);
1248+
let sans = Dim([1, 2, 4]);
1249+
let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap();
1250+
assert_eq!(d2, dans);
1251+
assert_eq!(s2, sans);
1252+
1253+
// Try something that doesn't fit
1254+
let d = Dim([0, 1, 3, 1]);
1255+
let s = Dim([2, 3, 4, 5]);
1256+
let res = squeeze_into::<_, Ix1>(&d, &s);
1257+
assert!(res.is_err());
1258+
let res = squeeze_into::<_, Ix0>(&d, &s);
1259+
assert!(res.is_err());
1260+
1261+
// Squeeze 0d to 0d
1262+
let d = Dim([]);
1263+
let s = Dim([]);
1264+
let res = squeeze_into::<_, Ix0>(&d, &s);
1265+
assert!(res.is_ok());
1266+
// grow 0d to 2d
1267+
let dans = Dim([1, 1]);
1268+
let sans = Dim([1, 1]);
1269+
let (d2, s2) = squeeze_into::<_, Ix2>(&d, &s).unwrap();
1270+
assert_eq!(d2, dans);
1271+
assert_eq!(s2, sans);
1272+
1273+
// Squeeze 0d to 0d dynamic
1274+
let d = dyndim(&[]);
1275+
let s = dyndim(&[]);
1276+
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
1277+
let dans = d;
1278+
let sans = s;
1279+
assert_eq!(d2, dans);
1280+
assert_eq!(s2, sans);
1281+
}
1282+
11511283
#[test]
11521284
fn test_merge_axes_from_the_back() {
11531285
let dyndim = Dim::<&[usize]>;

0 commit comments

Comments
 (0)