Commit 1994ab1

zou3519 authored and facebook-github-bot committed
Optimize alignBatchDimsAtFront (pytorch#41941)
Summary:
Pull Request resolved: pytorch#41941

If we know that the tensor already has the desired aligned size, we don't need to put in the effort to align it.

Test Plan:
- `./build/bin/vmap_test`
- `pytest test/test_vmap.py -v`

Reviewed By: albanD

Differential Revision: D22764101

Pulled By: zou3519

fbshipit-source-id: a2ab7ce7b98d405ae905f7fd98db097210bfad65
Parent: 5124436

1 file changed (+7, -1)

aten/src/ATen/VmapTransforms.cpp

Lines changed: 7 additions & 1 deletion
@@ -157,7 +157,13 @@ static Tensor alignBatchDimsAtFront(
   auto tensor_example_dim = physical_sizes.size() - /*num_batch_dims*/tensor_levels.count();
   TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);
 
-  std::vector<int64_t> aligned_sizes(requested_levels.count() + requested_example_dim, 1);
+  if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
+    // Optimization: no need to do another view if the physical tensor is
+    // already the correct shape
+    return physical_tensor;
+  }
+
+  VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);
 
   // align the example dims (non-bdims dims) first
   // aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]
