@@ -573,6 +573,97 @@ fn diag()
573
573
assert_eq ! ( d. dim( ) , 1 ) ;
574
574
}
575
575
576
+ /// Check that the merged shape is correct.
577
+ ///
578
+ /// Note that this does not check the strides in the "merged" case!
579
+ #[ test]
580
+ fn merge_axes ( ) {
581
+ macro_rules! assert_merged {
582
+ ( $arr: expr, $slice: expr, $take: expr, $into: expr) => {
583
+ let mut v = $arr. slice( $slice) ;
584
+ let merged_len = v. len_of( Axis ( $take) ) * v. len_of( Axis ( $into) ) ;
585
+ assert!( v. merge_axes( Axis ( $take) , Axis ( $into) ) ) ;
586
+ assert_eq!( v. len_of( Axis ( $take) ) , if merged_len == 0 { 0 } else { 1 } ) ;
587
+ assert_eq!( v. len_of( Axis ( $into) ) , merged_len) ;
588
+ }
589
+ }
590
+ macro_rules! assert_not_merged {
591
+ ( $arr: expr, $slice: expr, $take: expr, $into: expr) => {
592
+ let mut v = $arr. slice( $slice) ;
593
+ let old_dim = v. raw_dim( ) ;
594
+ let old_strides = v. strides( ) . to_owned( ) ;
595
+ assert!( !v. merge_axes( Axis ( $take) , Axis ( $into) ) ) ;
596
+ assert_eq!( v. raw_dim( ) , old_dim) ;
597
+ assert_eq!( v. strides( ) , & old_strides[ ..] ) ;
598
+ }
599
+ }
600
+
601
+ let a = Array4 :: < u8 > :: zeros ( ( 3 , 4 , 5 , 4 ) ) ;
602
+
603
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 0 , 0 ) ;
604
+ assert_merged ! ( a, s![ .., .., .., ..] , 0 , 1 ) ;
605
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 0 , 2 ) ;
606
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 0 , 3 ) ;
607
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 1 , 0 ) ;
608
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 1 , 1 ) ;
609
+ assert_merged ! ( a, s![ .., .., .., ..] , 1 , 2 ) ;
610
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 1 , 3 ) ;
611
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 2 , 1 ) ;
612
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 2 , 2 ) ;
613
+ assert_merged ! ( a, s![ .., .., .., ..] , 2 , 3 ) ;
614
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 3 , 0 ) ;
615
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 3 , 1 ) ;
616
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 3 , 2 ) ;
617
+ assert_not_merged ! ( a, s![ .., .., .., ..] , 3 , 3 ) ;
618
+
619
+ assert_merged ! ( a, s![ .., .., .., ..; 2 ] , 0 , 1 ) ;
620
+ assert_not_merged ! ( a, s![ .., .., .., ..; 2 ] , 1 , 0 ) ;
621
+ assert_merged ! ( a, s![ .., .., .., ..; 2 ] , 1 , 2 ) ;
622
+ assert_not_merged ! ( a, s![ .., .., .., ..; 2 ] , 2 , 1 ) ;
623
+ assert_merged ! ( a, s![ .., .., .., ..; 2 ] , 2 , 3 ) ;
624
+ assert_not_merged ! ( a, s![ .., .., .., ..; 2 ] , 3 , 2 ) ;
625
+
626
+ assert_merged ! ( a, s![ .., .., .., ..3 ] , 0 , 1 ) ;
627
+ assert_not_merged ! ( a, s![ .., .., .., ..3 ] , 1 , 0 ) ;
628
+ assert_merged ! ( a, s![ .., .., .., ..3 ] , 1 , 2 ) ;
629
+ assert_not_merged ! ( a, s![ .., .., .., ..3 ] , 2 , 1 ) ;
630
+ assert_not_merged ! ( a, s![ .., .., .., ..3 ] , 2 , 3 ) ;
631
+
632
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 0 , 1 ) ;
633
+ assert_not_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 0 ) ;
634
+ assert_not_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 2 ) ;
635
+ assert_not_merged ! ( a, s![ .., .., ..; 2 , ..] , 2 , 3 ) ;
636
+
637
+ assert_merged ! ( a, s![ .., ..; 2 , .., ..] , 0 , 1 ) ;
638
+ assert_not_merged ! ( a, s![ .., ..; 2 , .., ..] , 1 , 0 ) ;
639
+ assert_not_merged ! ( a, s![ .., ..; 2 , .., ..] , 1 , 2 ) ;
640
+ assert_merged ! ( a, s![ .., ..; 2 , .., ..] , 2 , 3 ) ;
641
+ assert_not_merged ! ( a, s![ .., ..; 2 , .., ..] , 3 , 2 ) ;
642
+
643
+ let a = Array4 :: < u8 > :: zeros ( ( 3 , 1 , 5 , 1 ) . f ( ) ) ;
644
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 0 , 1 ) ;
645
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 0 , 3 ) ;
646
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 0 ) ;
647
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 1 ) ;
648
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 2 ) ;
649
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 3 ) ;
650
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 2 , 1 ) ;
651
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 2 , 3 ) ;
652
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 3 , 0 ) ;
653
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 3 , 1 ) ;
654
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 3 , 2 ) ;
655
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 3 , 3 ) ;
656
+
657
+ let a = Array4 :: < u8 > :: zeros ( ( 3 , 0 , 5 , 1 ) ) ;
658
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 0 , 1 ) ;
659
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 1 ) ;
660
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 2 , 1 ) ;
661
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 3 , 1 ) ;
662
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 0 ) ;
663
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 2 ) ;
664
+ assert_merged ! ( a, s![ .., .., ..; 2 , ..] , 1 , 3 ) ;
665
+ }
666
+
576
667
#[ test]
577
668
fn swapaxes ( )
578
669
{
0 commit comments