
Commit 0571cfd

zou3519 authored and facebook-github-bot committed
Implement MultiBatchVmapTransform::logicalToPhysical(TensorList) (pytorch#41942)
Summary:
Pull Request resolved: pytorch#41942

This function:
- permutes all batch dims to the front of the tensors
- aligns all the batch dims to the collective levels of all the tensors
- expands all of the batch dims such that they are present in each of the result tensors

This function is useful for the next diff up on the stack (which is implementing a fallback kernel for BatchedTensor). It's also useful in general for implementing batching rules on operators that take in multiple batch dimensions at the front of each tensor (but we don't have too many of those in PyTorch).

Test Plan:
- `./build/bin/vmap_test`

Reviewed By: ezyang

Differential Revision: D22764104

Pulled By: zou3519

fbshipit-source-id: d42cc8824a1bcf258687de164b7853af52852f53
1 parent 1994ab1 commit 0571cfd
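For orientation, a short sketch of the new overload in action, adapted from the first case of TestMultiBatchVmapTransformBatchedBatched added in this commit (makeBatched is the test-only helper from vmap_test.cpp, not a public API):

    // Two logical tensors, batched over vmap levels 0 and 1 on different dims.
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({B1, 2, 3, B0});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/3}, {/*lvl*/1, /*dim*/0}});

    // The TensorList overload permutes every batch dim to the front and
    // aligns them to the collective levels {0, 1} of both inputs.
    auto views = MultiBatchVmapTransform::logicalToPhysical({batched_x, batched_y});
    // views[0].tensor() and views[1].tensor() both have sizes {B0, B1, 2, 3},
    // i.e. at::movedim(x, {1, 3}, {0, 1}) and at::movedim(y, {0, 3}, {1, 0}).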

File tree

3 files changed: +271 -6 lines changed
  aten/src/ATen/VmapTransforms.cpp
  aten/src/ATen/VmapTransforms.h
  aten/src/ATen/test/vmap_test.cpp

aten/src/ATen/VmapTransforms.cpp

Lines changed: 57 additions & 5 deletions

@@ -55,11 +55,6 @@ VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
   return { permuteBatchDimsToFront(batched), createLevelsBitset(batched->bdims()) };
 }
 
-std::vector<VmapPhysicalView>
-MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) {
-  TORCH_INTERNAL_ASSERT(false, "NYI");
-}
-
 int64_t VmapPhysicalView::numBatchDims() const {
   return levels_.count();
 }
@@ -186,6 +181,63 @@ static Tensor alignBatchDimsAtFront(
   return physical_tensor.view(aligned_sizes);
 }
 
+// The algorithm is as follows:
+// 1. Figure out what all of the collective levels in `logical_tensors` is.
+// 2. Move all batch dims to the front of the tensors and add extra dims
+//    of size 1. At this point, every tensor will have a dimension for
+//    each of the collective levels.
+// 3. Compute the batch_sizes.
+// 4. Expand each physical tensor so that they have output batch size equal
+//    to `batch_sizes`
+VmapPhysicalViewVec
+MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) {
+  // Figure out all of the collective vmap levels in `logical_tensors`.
+  std::bitset<kVmapNumLevels> collective_levels;
+  for (const auto& logical_tensor : logical_tensors) {
+    auto* batched = maybeGetBatched(logical_tensor);
+    if (batched) {
+      collective_levels |= createLevelsBitset(batched->bdims());
+    }
+  }
+
+  // Populate physical_tensors.
+  // This contains a list of regular (non-Batched) Tensors where all of the
+  // batch dims have been moved to the front of the tensor. Any previously
+  // non-existing batch dims get added to the tensors as new dimensions of size 1.
+  std::vector<Tensor> physical_tensors;
+  int64_t num_batch_dims = collective_levels.count();
+  for (const auto& logical_tensor : logical_tensors) {
+    auto requested_example_dim = /*logical_dim*/logical_tensor.dim();
+    auto physical_tensor = alignBatchDimsAtFront(
+        logical_tensor, collective_levels, requested_example_dim);
+    physical_tensors.push_back(std::move(physical_tensor));
+  }
+
+  // Compute batch_sizes
+  VmapDimVector batch_sizes(num_batch_dims, 1);
+  for (const auto& physical_tensor : physical_tensors) {
+    auto physical_sizes = physical_tensor.sizes();
+    for (int64_t dim = 0; dim < num_batch_dims; dim++) {
+      if (physical_sizes[dim] != 1) {
+        batch_sizes[dim] = physical_sizes[dim];
+      }
+    }
+  }
+
+  // Expand each physical_tensor so that it has batch sizes `batch_sizes`
+  VmapPhysicalViewVec result;
+  for (const auto& physical_tensor : physical_tensors) {
+    VmapDimVector expanded_size(batch_sizes.begin(), batch_sizes.end());
+    auto physical_sizes = physical_tensor.sizes();
+    expanded_size.insert(
+        expanded_size.end(),
+        physical_sizes.begin() + num_batch_dims,
+        physical_sizes.end());
+    result.emplace_back(physical_tensor.expand(expanded_size), collective_levels);
+  }
+  return result;
+}
+
 static std::pair<std::bitset<kVmapNumLevels>,int64_t>
 getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
   TORCH_INTERNAL_ASSERT(logical_tensors.size() > 0);
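To make the four steps above concrete, here is an illustrative walkthrough (not part of the diff), using the shapes from the "broadcasted batch dims" test case added below:

    // Inputs:
    //   x: sizes {B0, B2, 2, 3}, batched at levels {0, 2}
    //   y: sizes {B0, B1, 2, 3}, batched at levels {0, 1}
    //
    // Step 1: collective_levels = {0, 1, 2}, so num_batch_dims = 3.
    // Step 2: alignBatchDimsAtFront inserts size-1 dims for missing levels:
    //   x_physical: {B0, 1,  B2, 2, 3}   // level 1 absent in x
    //   y_physical: {B0, B1, 1,  2, 3}   // level 2 absent in y
    // Step 3: batch_sizes takes the non-1 size per leading dim: {B0, B1, B2}.
    // Step 4: each physical tensor is expanded to batch_sizes + example sizes:
    //   x_physical.expand({B0, B1, B2, 2, 3})
    //   y_physical.expand({B0, B1, B2, 2, 3})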

aten/src/ATen/VmapTransforms.h

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
 // and returns a VmapPhysicalView on the tensor(s).
 struct TORCH_API MultiBatchVmapTransform {
   static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
-  static std::vector<VmapPhysicalView> logicalToPhysical(TensorList logical_tensors);
+  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
 };
 
 // VmapTransform for operators that broadcast all inputs.
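The new return type VmapPhysicalViewVec is not defined in this hunk. Presumably it is a SmallVector alias analogous to VmapDimVector above, which avoids a heap allocation when only a few inputs are transformed; the sketch below is an assumption, not part of this commit:

    // Assumed shape of the alias, for illustration only; the real definition
    // lives elsewhere in VmapTransforms.h and the element count is hypothetical.
    using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, /*e.g.*/ 4>;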

aten/src/ATen/test/vmap_test.cpp

Lines changed: 213 additions & 0 deletions

@@ -850,4 +850,217 @@ TEST(VmapTest, TestBatchedTensorPermute) {
   }
 }
 
+static void checkMultiBatchVmapTransform(TensorList inputs, TensorList expected_outputs) {
+  auto outputs = MultiBatchVmapTransform::logicalToPhysical(inputs);
+  ASSERT_EQ(outputs.size(), expected_outputs.size());
+  for (int64_t idx = 0; idx < outputs.size(); idx++) {
+    const auto& output = outputs[idx].tensor();
+    ASSERT_EQ(output.data_ptr(), expected_outputs[idx].data_ptr());
+    ASSERT_EQ(output.sizes(), expected_outputs[idx].sizes());
+    ASSERT_TRUE(at::allclose(output, expected_outputs[idx]));
+  }
+}
+
+TEST(VmapTest, TestMultiBatchVmapTransformBatchedBatched) {
+  {
+    // Check that batch dims get moved to the front
+    int64_t B0 = 5, B1 = 7;
+    Tensor x = at::randn({2, B0, 3, B1});
+    Tensor y = at::randn({B1, 2, 3, B0});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});
+    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/3}, {/*lvl*/1, /*dim*/0}});
+
+    checkMultiBatchVmapTransform(
+        {batched_x, batched_y},
+        {at::movedim(x, {1, 3}, {0, 1}), at::movedim(y, {0, 3}, {1, 0})});
+  }
+  {
+    // Check that batch dims become broadcasted and are present in all returns
+    int64_t B0 = 5, B1 = 7, B2 = 9;
+    Tensor x = at::randn({B0, B2, 2, 3});
+    Tensor y = at::randn({B0, B1, 2, 3});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/2, /*dim*/1}});
+    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});
+
+    checkMultiBatchVmapTransform(
+        {batched_x, batched_y},
+        {x.unsqueeze(1).expand({B0, B1, B2, 2, 3}), y.unsqueeze(2).expand({B0, B1, B2, 2, 3})});
+  }
+  {
+    // Check operation on tensors of different logical dims
+    int64_t B0 = 5;
+    Tensor x = at::randn({B0, 3});
+    Tensor y = at::randn({B0, 2, 3});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}});
+    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}});
+
+    checkMultiBatchVmapTransform({batched_x, batched_y}, {x, y});
+  }
+  {
+    // More complicated example with two tensors.
+    int64_t B0 = 5, B1 = 7, B2 = 11, B3 = 13;
+    Tensor x = at::randn({2, B0, 3, B2});
+    Tensor y = at::randn({B3, 3, B1});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/2, /*dim*/3}});
+    Tensor batched_y = makeBatched(y, {{/*lvl*/1, /*dim*/2}, {/*lvl*/3, /*dim*/0}});
+
+    checkMultiBatchVmapTransform(
+        {batched_x, batched_y},
+        {
+          x.permute({1, 3, 0, 2}).view({B0, 1, B2, 1, 2, 3}).expand({B0, B1, B2, B3, 2, 3}),
+          y.permute({2, 0, 1}).view({1, B1, 1, B3, 3}).expand({B0, B1, B2, B3, 3}),
+        });
+  }
+  {
+    // Edge case: BatchedTensor "scalar" handling
+    int64_t B0 = 5, B2 = 11;
+    Tensor x = at::randn({B0});
+    Tensor y = at::randn({B0, B2});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}});
+    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});
+
+    checkMultiBatchVmapTransform({batched_x, batched_y}, {x.view({B0, 1}).expand({B0, B2}), y});
+    checkMultiBatchVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1}).expand({B0, B2})});
+  }
+  {
+    // Edge case: Only one tensor is a "batchedtensor scalar"
+    int64_t B0 = 5, B2 = 11;
+    Tensor x = at::randn({B0});
+    Tensor y = at::randn({B0, B2, 2});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}});
+    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});
+
+    checkMultiBatchVmapTransform({batched_x, batched_y}, {x.view({B0, 1}).expand({B0, B2}), y});
+    checkMultiBatchVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1}).expand({B0, B2})});
+  }
+}
+
+TEST(VmapTest, TestMultiBatchVmapTransformBatchedUnbatched) {
+  {
+    // Check same example size
+    int64_t B0 = 5, B1 = 7;
+    Tensor x = at::randn({2, B0, 3, B1});
+    Tensor y = at::randn({2, 3});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});
+
+    checkMultiBatchVmapTransform(
+        {batched_x, y},
+        {at::movedim(x, {1, 3}, {0, 1}), y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3})});
+    checkMultiBatchVmapTransform(
+        {y, batched_x},
+        {y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3}), at::movedim(x, {1, 3}, {0, 1})});
+  }
+  {
+    // BatchedTensor has higher example dim than non-batched-tensor
+    int64_t B0 = 5, B1 = 7;
+    Tensor x = at::randn({B0, B1, 2, 3});
+    Tensor y = at::randn({3});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});
+
+    checkMultiBatchVmapTransform(
+        {batched_x, y}, {x, y.view({1, 1, 3}).expand({B0, B1, 3})});
+    checkMultiBatchVmapTransform(
+        {y, batched_x}, {y.view({1, 1, 3}).expand({B0, B1, 3}), x});
+  }
+  {
+    // BatchedTensor has lower example dim than non-batched-tensor
+    int64_t B0 = 5, B1 = 7;
+    Tensor x = at::randn({B0, B1, 3});
+    Tensor y = at::randn({2, 3});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});
+
+    checkMultiBatchVmapTransform(
+        {batched_x, y}, {x.view({B0, B1, 3}), y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3})});
+    checkMultiBatchVmapTransform(
+        {y, batched_x}, {y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3}), x.view({B0, B1, 3})});
+  }
+  {
+    // Scalar handling
+    int64_t B0 = 5, B1 = 7;
+    Tensor x = at::randn({B0, B1});
+    Tensor y = at::randn({});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});
+
+    checkMultiBatchVmapTransform({batched_x, y}, {x, y.view({1, 1}).expand({B0, B1})});
+    checkMultiBatchVmapTransform({y, batched_x}, {y.view({1, 1}).expand({B0, B1}), x});
+  }
+}
+
+TEST(VmapTest, TestMultiBatchVmapTransformMaxLevels) {
+  {
+    // inputs have all 64 levels
+    auto x = randn(std::vector<int64_t>(kVmapNumLevels, 1));
+    auto y = randn(std::vector<int64_t>(kVmapNumLevels, 1));
+    auto batched_x = makeBatched(x, maxBatchDimsAtFront());
+    auto batched_y = makeBatched(y, maxBatchDimsAtFront());
+
+    checkMultiBatchVmapTransform({batched_x, batched_y}, {x, y});
+  }
+  {
+    // inputs don't have all 64 levels, but results do.
+    int64_t split = 19;
+    auto x = randn(std::vector<int64_t>(split, 1));
+    auto y = randn(std::vector<int64_t>(kVmapNumLevels - split, 1));
+
+    auto tmp = maxBatchDimsAtFront();
+    BatchDims x_bdims(tmp.begin(), tmp.begin() + split);
+
+    // Construct y_bdims.
+    int64_t dim = 0;
+    auto y_bdims_vector = fmap(
+        ArrayRef<BatchDim>(tmp.begin() + split, tmp.end()),
+        [&](const BatchDim& bdim) -> BatchDim {
+          return { bdim.level(), dim++ };
+        });
+    BatchDims y_bdims(y_bdims_vector.begin(), y_bdims_vector.end());
+
+    auto batched_x = makeBatched(x, x_bdims);
+    auto batched_y = makeBatched(y, y_bdims);
+
+    auto expected_size = std::vector<int64_t>(kVmapNumLevels, 1);
+    checkMultiBatchVmapTransform(
+        {batched_x, batched_y},
+        {x.view(expected_size), y.view(expected_size)});
+  }
+}
+
+TEST(VmapTest, TestMultiBatchVmapTransformMultipleTensors) {
+  // Test with three (all batched) tensors
+  {
+    int64_t B0 = 5, B1 = 7, B2 = 9;
+    Tensor x = at::randn({2, B0, 3, B1});
+    Tensor y = at::randn({B1, 4});
+    Tensor z = at::randn({2, B2});
+    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});
+    Tensor batched_y = makeBatched(y, {{/*lvl*/1, /*dim*/0}});
+    Tensor batched_z = makeBatched(z, {{/*lvl*/2, /*dim*/1}});
+
+    checkMultiBatchVmapTransform(
+        {batched_x, batched_y, batched_z},
+        {
+          at::movedim(x, {1, 3}, {0, 1}).view({B0, B1, 1, 2, 3}).expand({B0, B1, B2, 2, 3}),
+          y.view({1, B1, 1, 4}).expand({B0, B1, B2, 4}),
+          z.t().view({1, 1, B2, 2}).expand({B0, B1, B2, 2}),
+        });
+  }
+  // Test with three tensors, some batched, some unbatched
+  {
+    int64_t B0 = 5, B1 = 7, B2 = 9;
+    Tensor x = at::randn({2, 3});
+    Tensor y = at::randn({4, B0});
+    Tensor z = at::randn({B1, 2, B2});
+    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/1}});
+    Tensor batched_z = makeBatched(z, {{/*lvl*/1, /*dim*/0}, {/*lvl*/2, /*dim*/2}});
+
+    checkMultiBatchVmapTransform(
+        {x, batched_y, batched_z},
+        {
+          x.view({1, 1, 1, 2, 3}).expand({B0, B1, B2, 2, 3}),
+          y.t().view({B0, 1, 1, 4}).expand({B0, B1, B2, 4}),
+          z.permute({0, 2, 1}).view({1, B1, B2, 2}).expand({B0, B1, B2, 2}),
+        });
+  }
+}
+
+
 } // namespace
