@@ -34,12 +34,17 @@ static void scatter_gather_dtype_check(
34
34
// Test:
35
35
// 1. index.size(d) == self.size(d) for all d != dim
36
36
// 2. index.size(d) <= src.size(d) for all d != dim
37
+ // 3. index.dim() == self.dim() == src.dim()
37
38
static void gather_shape_check (const Tensor& self, int64_t dim,
38
39
const Tensor& index, const Tensor& src
39
40
) {
40
41
auto self_dims = ensure_nonempty_dim (self.dim ());
41
-
42
42
TORCH_CHECK (self_dims == ensure_nonempty_dim (index.dim ()),
43
+ " Index tensor must have the same number of dimensions as out tensor"
44
+ );
45
+
46
+ auto src_dims = ensure_nonempty_dim (src.dim ());
47
+ TORCH_CHECK (src_dims == ensure_nonempty_dim (index.dim ()),
43
48
" Index tensor must have the same number of dimensions as input tensor"
44
49
);
45
50
@@ -66,10 +71,16 @@ static void gather_shape_check(const Tensor& self, int64_t dim,
66
71
// Tests:
67
72
// 1. index.size(d) <= self.size(d) for all d != dim
68
73
// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
74
+ // 3. index.dim() == self.dim() == src.dim()
69
75
static void scatter_shape_check (
70
76
const Tensor& self, int64_t dim, const Tensor& index,
71
77
const c10::optional<Tensor>& src_opt = c10::nullopt
72
78
) {
79
+ TORCH_CHECK (
80
+ ensure_nonempty_dim (self.dim ()) == ensure_nonempty_dim (index.dim ()),
81
+ " Index tensor must have the same number of dimensions as self tensor"
82
+ );
83
+
73
84
bool is_wrong_shape = false ;
74
85
int64_t self_dims = ensure_nonempty_dim (self.dim ());
75
86
@@ -97,6 +108,12 @@ static void scatter_shape_check(
97
108
98
109
if (src_opt.has_value ()) {
99
110
auto src = src_opt.value ();
111
+
112
+ TORCH_CHECK (
113
+ ensure_nonempty_dim (src.dim ()) == ensure_nonempty_dim (index.dim ()),
114
+ " Index tensor must have the same number of dimensions as src tensor"
115
+ );
116
+
100
117
TORCH_CHECK (!is_wrong_shape,
101
118
" Expected index " , index.sizes (),
102
119
" to be smaller than self " , self.sizes (),
0 commit comments