Skip to content

Commit 7c7c9c3

Browse files
gchanannikitaved
andauthored
scatter/gather - check that inputs are of the same dimensionality (pytorch#41890)
Co-authored-by: Nikita Vedeneev <[email protected]>
1 parent a2922f5 commit 7c7c9c3

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

aten/src/ATen/native/ScatterGatherChecks.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,17 @@ static void scatter_gather_dtype_check(
3434
// Test:
3535
// 1. index.size(d) == self.size(d) for all d != dim
3636
// 2. index.size(d) <= src.size(d) for all d != dim
37+
// 3. index.dim() == self.dim() == src.dim()
3738
static void gather_shape_check(const Tensor& self, int64_t dim,
3839
const Tensor& index, const Tensor& src
3940
) {
4041
auto self_dims = ensure_nonempty_dim(self.dim());
41-
4242
TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
43+
"Index tensor must have the same number of dimensions as out tensor"
44+
);
45+
46+
auto src_dims = ensure_nonempty_dim(src.dim());
47+
TORCH_CHECK(src_dims == ensure_nonempty_dim(index.dim()),
4348
"Index tensor must have the same number of dimensions as input tensor"
4449
);
4550

@@ -66,10 +71,16 @@ static void gather_shape_check(const Tensor& self, int64_t dim,
6671
// Tests:
6772
// 1. index.size(d) <= self.size(d) for all d != dim
6873
// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
74+
// 3. index.dim() == self.dim() == src.dim()
6975
static void scatter_shape_check(
7076
const Tensor& self, int64_t dim, const Tensor& index,
7177
const c10::optional<Tensor>& src_opt = c10::nullopt
7278
) {
79+
TORCH_CHECK(
80+
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
81+
"Index tensor must have the same number of dimensions as self tensor"
82+
);
83+
7384
bool is_wrong_shape = false;
7485
int64_t self_dims = ensure_nonempty_dim(self.dim());
7586

@@ -97,6 +108,12 @@ static void scatter_shape_check(
97108

98109
if (src_opt.has_value()) {
99110
auto src = src_opt.value();
111+
112+
TORCH_CHECK(
113+
ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
114+
"Index tensor must have the same number of dimensions as src tensor"
115+
);
116+
100117
TORCH_CHECK(!is_wrong_shape,
101118
"Expected index ", index.sizes(),
102119
" to be smaller than self ", self.sizes(),

test/test_torch.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2643,6 +2643,13 @@ def _test_gather(self, cast, test_bounds=True):
26432643
with self.assertRaisesRegex(RuntimeError, 'Expected self.dtype to be equal to src.dtype'):
26442644
torch.gather(src, dim, idx, out=expected.to(torch.int))
26452645

2646+
# checks for the same dimensionality
2647+
with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as input tensor'):
2648+
torch.gather(src, dim, idx.unsqueeze(-1))
2649+
2650+
with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as input tensor'):
2651+
torch.gather(src.unsqueeze(-1), dim, idx)
2652+
26462653
if test_bounds:
26472654
idx[0][0][0] = 23
26482655
self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))
@@ -2728,6 +2735,17 @@ def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True, *,
27282735
with self.assertRaisesRegex(RuntimeError, 'Expected dtype int64 for index'):
27292736
getattr(base.clone(), method)(dim, idx.type(torch.int), src)
27302737

2738+
# check for the same dimensionality
2739+
with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as self tensor'):
2740+
getattr(base.clone().unsqueeze(-1), method)(dim, idx, src)
2741+
2742+
with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as self tensor'):
2743+
getattr(base.clone(), method)(dim, idx.unsqueeze(-1), src)
2744+
2745+
if not is_scalar:
2746+
with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as src tensor'):
2747+
getattr(base.clone(), method)(dim, idx, src.unsqueeze(-1))
2748+
27312749
if test_bounds:
27322750
idx[0][0][0] = 34
27332751
with self.assertRaises(RuntimeError):

0 commit comments

Comments
 (0)