Skip to content

Commit b1dcf77

Browse files
committed
Test unaffected indices more wholly in test_setitem
1 parent 7feaa28 commit b1dcf77

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def assert_0d_equals(
301301
>>> x = xp.asarray([0, 1, 2])
302302
>>> res = xp.asarray(x, copy=True)
303303
>>> res[0] = 42
304-
>>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])
304+
>>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0])
305305
306306
is equivalent to
307307

array_api_tests/test_array_object.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,13 @@ def test_getitem(shape, dtype, data):
5555
if i is None:
5656
out_shape.append(1)
5757
else:
58+
side = shape[a]
5859
if isinstance(i, int):
60+
if i < 0:
61+
i += side
5962
axes_indices.append([i])
6063
else:
6164
assert isinstance(i, slice) # sanity check
62-
side = shape[a]
6365
indices = range(side)[i]
6466
axes_indices.append(indices)
6567
out_shape.append(len(indices))
@@ -102,9 +104,9 @@ def test_setitem(shape, dtypes, data):
102104
start_pos = _key.index(Ellipsis)
103105
_key = _key[:start_pos] + slices + _key[start_pos + 1 :]
104106
out_shape = []
105-
for a, i in enumerate(_key):
107+
108+
for i, side in zip(_key, shape):
106109
if isinstance(i, slice):
107-
side = shape[a]
108110
indices = range(side)[i]
109111
out_shape.append(len(indices))
110112
out_shape = tuple(out_shape)
@@ -119,7 +121,8 @@ def test_setitem(shape, dtypes, data):
119121

120122
ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
121123
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape")
122-
f_res = f"res[{sh.fmt_idx('x', key)}]"
124+
125+
f_res = sh.fmt_idx("x", key)
123126
if isinstance(value, get_args(Scalar)):
124127
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
125128
if math.isnan(value):
@@ -128,14 +131,21 @@ def test_setitem(shape, dtypes, data):
128131
assert res[key] == value, msg
129132
else:
130133
ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res)
131-
if all(isinstance(i, int) for i in _key): # TODO: normalise slices and ellipsis
132-
_key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape))
133-
unaffected_indices = list(sh.ndindex(res.shape))
134-
unaffected_indices.remove(_key)
135-
for idx in unaffected_indices:
136-
ph.assert_0d_equals(
137-
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
138-
)
134+
135+
axes_indices = []
136+
for i, side in zip(_key, shape):
137+
if isinstance(i, int):
138+
if i < 0:
139+
i += side
140+
axes_indices.append([i])
141+
else:
142+
indices = range(side)[i]
143+
axes_indices.append(indices)
144+
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
145+
for idx in unaffected_indices:
146+
ph.assert_0d_equals(
147+
"__setitem__", f"old {f_res}", x[idx], f"modified {f_res}", res[idx]
148+
)
139149

140150

141151
@pytest.mark.data_dependent_shapes

0 commit comments

Comments
 (0)