@@ -762,6 +762,53 @@ where
762
762
* strides = new_strides;
763
763
}
764
764
765
+ /// Remove axes with length one, except never removing the last axis.
766
+ pub ( crate ) fn squeeze_into < D , E > ( dim : & D , strides : & D ) -> Result < ( E , E ) , ShapeError >
767
+ where
768
+ D : Dimension ,
769
+ E : Dimension ,
770
+ {
771
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
772
+
773
+ // Count axes with dim == 1; we keep axes with d == 0 or d > 1
774
+ let mut ndim_new = 0 ;
775
+ for & d in dim. slice ( ) {
776
+ if d != 1 { ndim_new += 1 ; }
777
+ }
778
+ let mut fill_ones = 0 ;
779
+ if let Some ( e_ndim) = E :: NDIM {
780
+ if e_ndim < ndim_new {
781
+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
782
+ }
783
+ fill_ones = e_ndim - ndim_new;
784
+ ndim_new = e_ndim;
785
+ } else {
786
+ // dynamic-dimensional
787
+ // use minimum one dimension unless input has less than one dim
788
+ if dim. ndim ( ) > 0 && ndim_new == 0 {
789
+ ndim_new = 1 ;
790
+ fill_ones = 1 ;
791
+ }
792
+ }
793
+
794
+ let mut new_dim = E :: zeros ( ndim_new) ;
795
+ let mut new_strides = E :: zeros ( ndim_new) ;
796
+ let mut i = 0 ;
797
+ while i < fill_ones {
798
+ new_dim[ i] = 1 ;
799
+ new_strides[ i] = 1 ;
800
+ i += 1 ;
801
+ }
802
+ for ( & d, & s) in izip ! ( dim. slice( ) , strides. slice( ) ) {
803
+ if d != 1 {
804
+ new_dim[ i] = d;
805
+ new_strides[ i] = s;
806
+ i += 1 ;
807
+ }
808
+ }
809
+ Ok ( ( new_dim, new_strides) )
810
+ }
811
+
765
812
766
813
/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
767
814
/// stride
@@ -1148,6 +1195,91 @@ mod test {
1148
1195
assert_eq ! ( s, sans) ;
1149
1196
}
1150
1197
1198
+ #[ test]
1199
+ #[ cfg( feature = "std" ) ]
1200
+ fn test_squeeze_into ( ) {
1201
+ use super :: squeeze_into;
1202
+
1203
+ let dyndim = Dim :: < & [ usize ] > ;
1204
+
1205
+ // squeeze to ixdyn
1206
+ let d = dyndim ( & [ 1 , 2 , 1 , 1 , 3 , 1 ] ) ;
1207
+ let s = dyndim ( & [ !0 , !0 , !0 , 9 , 10 , !0 ] ) ;
1208
+ let dans = dyndim ( & [ 2 , 3 ] ) ;
1209
+ let sans = dyndim ( & [ !0 , 10 ] ) ;
1210
+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1211
+ assert_eq ! ( d2, dans) ;
1212
+ assert_eq ! ( s2, sans) ;
1213
+
1214
+ // squeeze to ixdyn does not go below 1D
1215
+ let d = dyndim ( & [ 1 , 1 ] ) ;
1216
+ let s = dyndim ( & [ 3 , 4 ] ) ;
1217
+ let dans = dyndim ( & [ 1 ] ) ;
1218
+ let sans = dyndim ( & [ 1 ] ) ;
1219
+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1220
+ assert_eq ! ( d2, dans) ;
1221
+ assert_eq ! ( s2, sans) ;
1222
+
1223
+ let d = Dim ( [ 1 , 1 ] ) ;
1224
+ let s = Dim ( [ 3 , 4 ] ) ;
1225
+ let dans = Dim ( [ 1 ] ) ;
1226
+ let sans = Dim ( [ 1 ] ) ;
1227
+ let ( d2, s2) = squeeze_into :: < _ , Ix1 > ( & d, & s) . unwrap ( ) ;
1228
+ assert_eq ! ( d2, dans) ;
1229
+ assert_eq ! ( s2, sans) ;
1230
+
1231
+ // squeeze to zero-dim
1232
+ let ( d2, s2) = squeeze_into :: < _ , Ix0 > ( & d, & s) . unwrap ( ) ;
1233
+ assert_eq ! ( d2, Ix0 ( ) ) ;
1234
+ assert_eq ! ( s2, Ix0 ( ) ) ;
1235
+
1236
+ let d = Dim ( [ 0 , 1 , 3 , 4 ] ) ;
1237
+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1238
+ let dans = Dim ( [ 0 , 3 , 4 ] ) ;
1239
+ let sans = Dim ( [ 2 , 4 , 5 ] ) ;
1240
+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1241
+ assert_eq ! ( d2, dans) ;
1242
+ assert_eq ! ( s2, sans) ;
1243
+
1244
+ // Pad with ones
1245
+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1246
+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1247
+ let dans = Dim ( [ 1 , 0 , 3 ] ) ;
1248
+ let sans = Dim ( [ 1 , 2 , 4 ] ) ;
1249
+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1250
+ assert_eq ! ( d2, dans) ;
1251
+ assert_eq ! ( s2, sans) ;
1252
+
1253
+ // Try something that doesn't fit
1254
+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1255
+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1256
+ let res = squeeze_into :: < _ , Ix1 > ( & d, & s) ;
1257
+ assert ! ( res. is_err( ) ) ;
1258
+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1259
+ assert ! ( res. is_err( ) ) ;
1260
+
1261
+ // Squeeze 0d to 0d
1262
+ let d = Dim ( [ ] ) ;
1263
+ let s = Dim ( [ ] ) ;
1264
+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1265
+ assert ! ( res. is_ok( ) ) ;
1266
+ // grow 0d to 2d
1267
+ let dans = Dim ( [ 1 , 1 ] ) ;
1268
+ let sans = Dim ( [ 1 , 1 ] ) ;
1269
+ let ( d2, s2) = squeeze_into :: < _ , Ix2 > ( & d, & s) . unwrap ( ) ;
1270
+ assert_eq ! ( d2, dans) ;
1271
+ assert_eq ! ( s2, sans) ;
1272
+
1273
+ // Squeeze 0d to 0d dynamic
1274
+ let d = dyndim ( & [ ] ) ;
1275
+ let s = dyndim ( & [ ] ) ;
1276
+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1277
+ let dans = d;
1278
+ let sans = s;
1279
+ assert_eq ! ( d2, dans) ;
1280
+ assert_eq ! ( s2, sans) ;
1281
+ }
1282
+
1151
1283
#[ test]
1152
1284
fn test_merge_axes_from_the_back ( ) {
1153
1285
let dyndim = Dim :: < & [ usize ] > ;
0 commit comments