@@ -729,37 +729,77 @@ where
729
729
}
730
730
731
731
/// Remove axes with length one, except never removing the last axis.
732
+ ///
733
+ /// This function is a no-op for const dim.
732
734
pub ( crate ) fn squeeze < D > ( dim : & mut D , strides : & mut D )
733
735
where
734
736
D : Dimension ,
735
737
{
736
738
if let Some ( _) = D :: NDIM {
737
739
return ;
738
740
}
741
+
742
+ // infallible for dyn dim
743
+ let ( d, s) = squeeze_into ( dim, strides) . unwrap ( ) ;
744
+ * dim = d;
745
+ * strides = s;
746
+ }
747
+
748
+ /// Remove axes with length one, except never removing the last axis.
749
+ ///
750
+ /// Return an error if there are more non-unitary dimensions than can be stored
751
+ /// in `E`. Infallible for dyn dim.
752
+ ///
753
+ /// Squeeze does not shrink dyn dim down to smaller than 1D, but if the input is
754
+ /// dynamic 0D, the output can be too.
755
+ ///
756
+ /// For const dim, this may instead pad the dimensionality with ones if it needs
757
+ /// to grow to fill the target dimensionality; the dimension is padded in the
758
+ /// start.
759
+ pub ( crate ) fn squeeze_into < D , E > ( dim : & D , strides : & D ) -> Result < ( E , E ) , ShapeError >
760
+ where
761
+ D : Dimension ,
762
+ E : Dimension ,
763
+ {
739
764
debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
740
765
741
766
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
742
767
let mut ndim_new = 0 ;
743
768
for & d in dim. slice ( ) {
744
769
if d != 1 { ndim_new += 1 ; }
745
770
}
746
- ndim_new = Ord :: max ( 1 , ndim_new) ;
747
- let mut new_dim = D :: zeros ( ndim_new) ;
748
- let mut new_strides = D :: zeros ( ndim_new) ;
771
+ let mut fill_ones = 0 ;
772
+ if let Some ( e_ndim) = E :: NDIM {
773
+ if e_ndim < ndim_new {
774
+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
775
+ }
776
+ fill_ones = e_ndim - ndim_new;
777
+ ndim_new = e_ndim;
778
+ } else {
779
+ // dynamic-dimensional
780
+ // use minimum one dimension unless input has less than one dim
781
+ if dim. ndim ( ) > 0 && ndim_new == 0 {
782
+ ndim_new = 1 ;
783
+ fill_ones = 1 ;
784
+ }
785
+ }
786
+
787
+ let mut new_dim = E :: zeros ( ndim_new) ;
788
+ let mut new_strides = E :: zeros ( ndim_new) ;
749
789
let mut i = 0 ;
790
+ while i < fill_ones {
791
+ new_dim[ i] = 1 ;
792
+ new_strides[ i] = 1 ;
793
+ i += 1 ;
794
+ }
750
795
for ( & d, & s) in izip ! ( dim. slice( ) , strides. slice( ) ) {
751
796
if d != 1 {
752
797
new_dim[ i] = d;
753
798
new_strides[ i] = s;
754
799
i += 1 ;
755
800
}
756
801
}
757
- if i == 0 {
758
- new_dim[ i] = 1 ;
759
- new_strides[ i] = 1 ;
760
- }
761
- * dim = new_dim;
762
- * strides = new_strides;
802
+ Ok ( ( new_dim, new_strides) )
763
803
}
764
804
765
805
@@ -1148,6 +1188,91 @@ mod test {
1148
1188
assert_eq ! ( s, sans) ;
1149
1189
}
1150
1190
1191
+ #[ test]
1192
+ #[ cfg( feature = "std" ) ]
1193
+ fn test_squeeze_into ( ) {
1194
+ use super :: squeeze_into;
1195
+
1196
+ let dyndim = Dim :: < & [ usize ] > ;
1197
+
1198
+ // squeeze to ixdyn
1199
+ let d = dyndim ( & [ 1 , 2 , 1 , 1 , 3 , 1 ] ) ;
1200
+ let s = dyndim ( & [ !0 , !0 , !0 , 9 , 10 , !0 ] ) ;
1201
+ let dans = dyndim ( & [ 2 , 3 ] ) ;
1202
+ let sans = dyndim ( & [ !0 , 10 ] ) ;
1203
+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1204
+ assert_eq ! ( d2, dans) ;
1205
+ assert_eq ! ( s2, sans) ;
1206
+
1207
+ // squeeze to ixdyn does not go below 1D
1208
+ let d = dyndim ( & [ 1 , 1 ] ) ;
1209
+ let s = dyndim ( & [ 3 , 4 ] ) ;
1210
+ let dans = dyndim ( & [ 1 ] ) ;
1211
+ let sans = dyndim ( & [ 1 ] ) ;
1212
+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1213
+ assert_eq ! ( d2, dans) ;
1214
+ assert_eq ! ( s2, sans) ;
1215
+
1216
+ let d = Dim ( [ 1 , 1 ] ) ;
1217
+ let s = Dim ( [ 3 , 4 ] ) ;
1218
+ let dans = Dim ( [ 1 ] ) ;
1219
+ let sans = Dim ( [ 1 ] ) ;
1220
+ let ( d2, s2) = squeeze_into :: < _ , Ix1 > ( & d, & s) . unwrap ( ) ;
1221
+ assert_eq ! ( d2, dans) ;
1222
+ assert_eq ! ( s2, sans) ;
1223
+
1224
+ // squeeze to zero-dim
1225
+ let ( d2, s2) = squeeze_into :: < _ , Ix0 > ( & d, & s) . unwrap ( ) ;
1226
+ assert_eq ! ( d2, Ix0 ( ) ) ;
1227
+ assert_eq ! ( s2, Ix0 ( ) ) ;
1228
+
1229
+ let d = Dim ( [ 0 , 1 , 3 , 4 ] ) ;
1230
+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1231
+ let dans = Dim ( [ 0 , 3 , 4 ] ) ;
1232
+ let sans = Dim ( [ 2 , 4 , 5 ] ) ;
1233
+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1234
+ assert_eq ! ( d2, dans) ;
1235
+ assert_eq ! ( s2, sans) ;
1236
+
1237
+ // Pad with ones
1238
+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1239
+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1240
+ let dans = Dim ( [ 1 , 0 , 3 ] ) ;
1241
+ let sans = Dim ( [ 1 , 2 , 4 ] ) ;
1242
+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1243
+ assert_eq ! ( d2, dans) ;
1244
+ assert_eq ! ( s2, sans) ;
1245
+
1246
+ // Try something that doesn't fit
1247
+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1248
+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1249
+ let res = squeeze_into :: < _ , Ix1 > ( & d, & s) ;
1250
+ assert ! ( res. is_err( ) ) ;
1251
+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1252
+ assert ! ( res. is_err( ) ) ;
1253
+
1254
+ // Squeeze 0d to 0d
1255
+ let d = Dim ( [ ] ) ;
1256
+ let s = Dim ( [ ] ) ;
1257
+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1258
+ assert ! ( res. is_ok( ) ) ;
1259
+ // grow 0d to 2d
1260
+ let dans = Dim ( [ 1 , 1 ] ) ;
1261
+ let sans = Dim ( [ 1 , 1 ] ) ;
1262
+ let ( d2, s2) = squeeze_into :: < _ , Ix2 > ( & d, & s) . unwrap ( ) ;
1263
+ assert_eq ! ( d2, dans) ;
1264
+ assert_eq ! ( s2, sans) ;
1265
+
1266
+ // Squeeze 0d to 0d dynamic
1267
+ let d = dyndim ( & [ ] ) ;
1268
+ let s = dyndim ( & [ ] ) ;
1269
+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1270
+ let dans = d;
1271
+ let sans = s;
1272
+ assert_eq ! ( d2, dans) ;
1273
+ assert_eq ! ( s2, sans) ;
1274
+ }
1275
+
1151
1276
#[ test]
1152
1277
fn test_merge_axes_from_the_back ( ) {
1153
1278
let dyndim = Dim :: < & [ usize ] > ;
0 commit comments