@@ -850,4 +850,217 @@ TEST(VmapTest, TestBatchedTensorPermute) {
850
850
}
851
851
}
852
852
853
+ static void checkMultiBatchVmapTransform (TensorList inputs, TensorList expected_outputs) {
854
+ auto outputs = MultiBatchVmapTransform::logicalToPhysical (inputs);
855
+ ASSERT_EQ (outputs.size (), expected_outputs.size ());
856
+ for (int64_t idx = 0 ; idx < outputs.size (); idx++) {
857
+ const auto & output = outputs[idx].tensor ();
858
+ ASSERT_EQ (output.data_ptr (), expected_outputs[idx].data_ptr ());
859
+ ASSERT_EQ (output.sizes (), expected_outputs[idx].sizes ());
860
+ ASSERT_TRUE (at::allclose (output, expected_outputs[idx]));
861
+ }
862
+ }
863
+
864
+ TEST (VmapTest, TestMultiBatchVmapTransformBatchedBatched) {
865
+ {
866
+ // Check that batch dims get moved to the front
867
+ int64_t B0 = 5 , B1 = 7 ;
868
+ Tensor x = at::randn ({2 , B0, 3 , B1});
869
+ Tensor y = at::randn ({B1, 2 , 3 , B0});
870
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 1 }, {/* lvl*/ 1 , /* dim*/ 3 }});
871
+ Tensor batched_y = makeBatched (y, {{/* lvl*/ 0 , /* dim*/ 3 }, {/* lvl*/ 1 , /* dim*/ 0 }});
872
+
873
+ checkMultiBatchVmapTransform (
874
+ {batched_x, batched_y},
875
+ {at::movedim (x, {1 , 3 }, {0 , 1 }), at::movedim (y, {0 , 3 }, {1 , 0 })});
876
+ }
877
+ {
878
+ // Check that batch dims become broadcasted and are present in all returns
879
+ int64_t B0 = 5 , B1 = 7 , B2 = 9 ;
880
+ Tensor x = at::randn ({B0, B2, 2 , 3 });
881
+ Tensor y = at::randn ({B0, B1, 2 , 3 });
882
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 0 }, {/* lvl*/ 2 , /* dim*/ 1 }});
883
+ Tensor batched_y = makeBatched (y, {{/* lvl*/ 0 , /* dim*/ 0 }, {/* lvl*/ 1 , /* dim*/ 1 }});
884
+
885
+ checkMultiBatchVmapTransform (
886
+ {batched_x, batched_y},
887
+ {x.unsqueeze (1 ).expand ({B0, B1, B2, 2 , 3 }), y.unsqueeze (2 ).expand ({B0, B1, B2, 2 , 3 })});
888
+ }
889
+ {
890
+ // Check operation on tensors of different logical dims
891
+ int64_t B0 = 5 ;
892
+ Tensor x = at::randn ({B0, 3 });
893
+ Tensor y = at::randn ({B0, 2 , 3 });
894
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 0 }});
895
+ Tensor batched_y = makeBatched (y, {{/* lvl*/ 0 , /* dim*/ 0 }});
896
+
897
+ checkMultiBatchVmapTransform ({batched_x, batched_y}, {x, y});
898
+ }
899
+ {
900
+ // More complicated example with two tensors.
901
+ int64_t B0 = 5 , B1 = 7 , B2 = 11 , B3 = 13 ;
902
+ Tensor x = at::randn ({2 , B0, 3 , B2});
903
+ Tensor y = at::randn ({B3, 3 , B1});
904
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 1 }, {/* lvl*/ 2 , /* dim*/ 3 }});
905
+ Tensor batched_y = makeBatched (y, {{/* lvl*/ 1 , /* dim*/ 2 }, {/* lvl*/ 3 , /* dim*/ 0 }});
906
+
907
+ checkMultiBatchVmapTransform (
908
+ {batched_x, batched_y},
909
+ {
910
+ x.permute ({1 , 3 , 0 , 2 }).view ({B0, 1 , B2, 1 , 2 , 3 }).expand ({B0, B1, B2, B3, 2 , 3 }),
911
+ y.permute ({2 , 0 , 1 }).view ({1 , B1, 1 , B3, 3 }).expand ({B0, B1, B2, B3, 3 }),
912
+ });
913
+ }
914
+ {
915
+ // Edge case: BatchedTensor "scalar" handling
916
+ int64_t B0 = 5 , B2 = 11 ;
917
+ Tensor x = at::randn ({B0});
918
+ Tensor y = at::randn ({B0, B2});
919
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 0 }});
920
+ Tensor batched_y = makeBatched (y, {{/* lvl*/ 0 , /* dim*/ 0 }, {/* lvl*/ 1 , /* dim*/ 1 }});
921
+
922
+ checkMultiBatchVmapTransform ({batched_x, batched_y}, {x.view ({B0, 1 }).expand ({B0, B2}), y});
923
+ checkMultiBatchVmapTransform ({batched_y, batched_x}, {y, x.view ({B0, 1 }).expand ({B0, B2})});
924
+ }
925
+ {
926
+ // Edge case: Only one tensor is a "batchedtensor scalar"
927
+ int64_t B0 = 5 , B2 = 11 ;
928
+ Tensor x = at::randn ({B0});
929
+ Tensor y = at::randn ({B0, B2, 2 });
930
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 0 }});
931
+ Tensor batched_y = makeBatched (y, {{/* lvl*/ 0 , /* dim*/ 0 }, {/* lvl*/ 1 , /* dim*/ 1 }});
932
+
933
+ checkMultiBatchVmapTransform ({batched_x, batched_y}, {x.view ({B0, 1 }).expand ({B0, B2}), y});
934
+ checkMultiBatchVmapTransform ({batched_y, batched_x}, {y, x.view ({B0, 1 }).expand ({B0, B2})});
935
+ }
936
+ }
937
+
938
+ TEST (VmapTest, TestMultiBatchVmapTransformBatchedUnbatched) {
939
+ {
940
+ // Check same example size
941
+ int64_t B0 = 5 , B1 = 7 ;
942
+ Tensor x = at::randn ({2 , B0, 3 , B1});
943
+ Tensor y = at::randn ({2 , 3 });
944
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 1 }, {/* lvl*/ 1 , /* dim*/ 3 }});
945
+
946
+ checkMultiBatchVmapTransform (
947
+ {batched_x, y},
948
+ {at::movedim (x, {1 , 3 }, {0 , 1 }), y.view ({1 , 1 , 2 , 3 }).expand ({B0, B1, 2 , 3 })});
949
+ checkMultiBatchVmapTransform (
950
+ {y, batched_x},
951
+ {y.view ({1 , 1 , 2 , 3 }).expand ({B0, B1, 2 , 3 }), at::movedim (x, {1 , 3 }, {0 , 1 })});
952
+ }
953
+ {
954
+ // BatchedTensor has higher example dim than non-batched-tensor
955
+ int64_t B0 = 5 , B1 = 7 ;
956
+ Tensor x = at::randn ({B0, B1, 2 , 3 });
957
+ Tensor y = at::randn ({3 });
958
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 0 }, {/* lvl*/ 1 , /* dim*/ 1 }});
959
+
960
+ checkMultiBatchVmapTransform (
961
+ {batched_x, y}, {x, y.view ({1 , 1 , 3 }).expand ({B0, B1, 3 })});
962
+ checkMultiBatchVmapTransform (
963
+ {y, batched_x}, {y.view ({1 , 1 , 3 }).expand ({B0, B1, 3 }), x});
964
+ }
965
+ {
966
+ // BatchedTensor has lower example dim than non-batched-tensor
967
+ int64_t B0 = 5 , B1 = 7 ;
968
+ Tensor x = at::randn ({B0, B1, 3 });
969
+ Tensor y = at::randn ({2 , 3 });
970
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 0 }, {/* lvl*/ 1 , /* dim*/ 1 }});
971
+
972
+ checkMultiBatchVmapTransform (
973
+ {batched_x, y}, {x.view ({B0, B1, 3 }), y.view ({1 , 1 , 2 , 3 }).expand ({B0, B1, 2 , 3 })});
974
+ checkMultiBatchVmapTransform (
975
+ {y, batched_x}, {y.view ({1 , 1 , 2 , 3 }).expand ({B0, B1, 2 , 3 }), x.view ({B0, B1, 3 })});
976
+ }
977
+ {
978
+ // Scalar handling
979
+ int64_t B0 = 5 , B1 = 7 ;
980
+ Tensor x = at::randn ({B0, B1});
981
+ Tensor y = at::randn ({});
982
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 0 }, {/* lvl*/ 1 , /* dim*/ 1 }});
983
+
984
+ checkMultiBatchVmapTransform ({batched_x, y}, {x, y.view ({1 , 1 }).expand ({B0, B1})});
985
+ checkMultiBatchVmapTransform ({y, batched_x}, {y.view ({1 , 1 }).expand ({B0, B1}), x});
986
+ }
987
+ }
988
+
989
+ TEST (VmapTest, TestMultiBatchVmapTransformMaxLevels) {
990
+ {
991
+ // inputs have all 64 levels
992
+ auto x = randn (std::vector<int64_t >(kVmapNumLevels , 1 ));
993
+ auto y = randn (std::vector<int64_t >(kVmapNumLevels , 1 ));
994
+ auto batched_x = makeBatched (x, maxBatchDimsAtFront ());
995
+ auto batched_y = makeBatched (y, maxBatchDimsAtFront ());
996
+
997
+ checkMultiBatchVmapTransform ({batched_x, batched_y}, {x, y});
998
+ }
999
+ {
1000
+ // inputs don't have all 64 levels, but results do.
1001
+ int64_t split = 19 ;
1002
+ auto x = randn (std::vector<int64_t >(split, 1 ));
1003
+ auto y = randn (std::vector<int64_t >(kVmapNumLevels - split, 1 ));
1004
+
1005
+ auto tmp = maxBatchDimsAtFront ();
1006
+ BatchDims x_bdims (tmp.begin (), tmp.begin () + split);
1007
+
1008
+ // Construct y_bdims.
1009
+ int64_t dim = 0 ;
1010
+ auto y_bdims_vector = fmap (
1011
+ ArrayRef<BatchDim>(tmp.begin () + split, tmp.end ()),
1012
+ [&](const BatchDim& bdim) -> BatchDim {
1013
+ return { bdim.level (), dim++ };
1014
+ });
1015
+ BatchDims y_bdims (y_bdims_vector.begin (), y_bdims_vector.end ());
1016
+
1017
+ auto batched_x = makeBatched (x, x_bdims);
1018
+ auto batched_y = makeBatched (y, y_bdims);
1019
+
1020
+ auto expected_size = std::vector<int64_t >(kVmapNumLevels , 1 );
1021
+ checkMultiBatchVmapTransform (
1022
+ {batched_x, batched_y},
1023
+ {x.view (expected_size), y.view (expected_size)});
1024
+ }
1025
+ }
1026
+
1027
+ TEST (VmapTest, TestMultiBatchVmapTransformMultipleTensors) {
1028
+ // Test with three (all batched) tensors
1029
+ {
1030
+ int64_t B0 = 5 , B1 = 7 , B2 = 9 ;
1031
+ Tensor x = at::randn ({2 , B0, 3 , B1});
1032
+ Tensor y = at::randn ({B1, 4 });
1033
+ Tensor z = at::randn ({2 , B2});
1034
+ Tensor batched_x = makeBatched (x, {{/* lvl*/ 0 , /* dim*/ 1 }, {/* lvl*/ 1 , /* dim*/ 3 }});
1035
+ Tensor batched_y = makeBatched (y, {{/* lvl*/ 1 , /* dim*/ 0 }});
1036
+ Tensor batched_z = makeBatched (z, {{/* lvl*/ 2 , /* dim*/ 1 }});
1037
+
1038
+ checkMultiBatchVmapTransform (
1039
+ {batched_x, batched_y, batched_z},
1040
+ {
1041
+ at::movedim (x, {1 , 3 }, {0 , 1 }).view ({B0, B1, 1 , 2 , 3 }).expand ({B0, B1, B2, 2 , 3 }),
1042
+ y.view ({1 , B1, 1 , 4 }).expand ({B0, B1, B2, 4 }),
1043
+ z.t ().view ({1 , 1 , B2, 2 }).expand ({B0, B1, B2, 2 }),
1044
+ });
1045
+ }
1046
+ // Test with three tensors, some batched, some unbatched
1047
+ {
1048
+ int64_t B0 = 5 , B1 = 7 , B2 = 9 ;
1049
+ Tensor x = at::randn ({2 , 3 });
1050
+ Tensor y = at::randn ({4 , B0});
1051
+ Tensor z = at::randn ({B1, 2 , B2});
1052
+ Tensor batched_y = makeBatched (y, {{/* lvl*/ 0 , /* dim*/ 1 }});
1053
+ Tensor batched_z = makeBatched (z, {{/* lvl*/ 1 , /* dim*/ 0 }, {/* lvl*/ 2 , /* dim*/ 2 }});
1054
+
1055
+ checkMultiBatchVmapTransform (
1056
+ {x, batched_y, batched_z},
1057
+ {
1058
+ x.view ({1 , 1 , 1 , 2 , 3 }).expand ({B0, B1, B2, 2 , 3 }),
1059
+ y.t ().view ({B0, 1 , 1 , 4 }).expand ({B0, B1, B2, 4 }),
1060
+ z.permute ({0 , 2 , 1 }).view ({1 , B1, B2, 2 }).expand ({B0, B1, B2, 2 }),
1061
+ });
1062
+ }
1063
+ }
1064
+
1065
+
853
1066
} // namespace
0 commit comments