Skip to content

Commit 373dd48

Browse files
committed
get_indexed_axes_and_out_shape() util for indexing tests
1 parent 131037d commit 373dd48

File tree

1 file changed

+33
-38
lines changed

1 file changed

+33
-38
lines changed

array_api_tests/test_array_object.py

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import List, Union, get_args
3+
from typing import List, Sequence, Tuple, Union, get_args
44

55
import pytest
66
from hypothesis import assume, given, note
@@ -28,7 +28,7 @@ def scalar_objects(
2828
)
2929

3030

31-
def normalise_key(key: Index, shape: Shape):
31+
def normalise_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]:
3232
"""
3333
Normalise an indexing key.
3434
@@ -46,40 +46,52 @@ def normalise_key(key: Index, shape: Shape):
4646
return _key
4747

4848

49-
@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data())
50-
def test_getitem(shape, dtype, data):
51-
zero_sided = any(side == 0 for side in shape)
52-
if zero_sided:
53-
x = xp.zeros(shape, dtype=dtype)
54-
else:
55-
obj = data.draw(scalar_objects(dtype, shape), label="obj")
56-
x = xp.asarray(obj, dtype=dtype)
57-
note(f"{x=}")
58-
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
59-
60-
out = x[key]
49+
def get_indexed_axes_and_out_shape(
50+
key: Tuple[Union[int, slice, None], ...], shape: Shape
51+
) -> Tuple[Tuple[Sequence[int], ...], Shape]:
52+
"""
53+
From the (normalised) key and input shape, calculates:
6154
62-
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
63-
_key = normalise_key(key, shape)
55+
* indexed_axes: For each dimension, the axes which the key indexes.
56+
* out_shape: The resulting shape of indexing an array (of the input shape)
57+
with the key.
58+
"""
6459
axes_indices = []
6560
out_shape = []
6661
a = 0
67-
for i in _key:
62+
for i in key:
6863
if i is None:
6964
out_shape.append(1)
7065
else:
7166
side = shape[a]
7267
if isinstance(i, int):
7368
if i < 0:
7469
i += side
75-
axes_indices.append([i])
70+
axes_indices.append((i,))
7671
else:
77-
assert isinstance(i, slice) # sanity check
7872
indices = range(side)[i]
7973
axes_indices.append(indices)
8074
out_shape.append(len(indices))
8175
a += 1
82-
out_shape = tuple(out_shape)
76+
return tuple(axes_indices), tuple(out_shape)
77+
78+
79+
@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data())
80+
def test_getitem(shape, dtype, data):
81+
zero_sided = any(side == 0 for side in shape)
82+
if zero_sided:
83+
x = xp.zeros(shape, dtype=dtype)
84+
else:
85+
obj = data.draw(scalar_objects(dtype, shape), label="obj")
86+
x = xp.asarray(obj, dtype=dtype)
87+
note(f"{x=}")
88+
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
89+
90+
out = x[key]
91+
92+
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
93+
_key = normalise_key(key, shape)
94+
axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
8395
ph.assert_shape("__getitem__", out.shape, out_shape)
8496
out_zero_sided = any(side == 0 for side in out_shape)
8597
if not zero_sided and not out_zero_sided:
@@ -109,13 +121,7 @@ def test_setitem(shape, dtypes, data):
109121
note(f"{x=}")
110122
key = data.draw(xps.indices(shape=shape), label="key")
111123
_key = normalise_key(key, shape)
112-
out_shape = []
113-
114-
for i, side in zip(_key, shape):
115-
if isinstance(i, slice):
116-
indices = range(side)[i]
117-
out_shape.append(len(indices))
118-
out_shape = tuple(out_shape)
124+
axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
119125
value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape)
120126
if out_shape == ():
121127
# We can pass scalars if we're only indexing one element
@@ -127,7 +133,6 @@ def test_setitem(shape, dtypes, data):
127133

128134
ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
129135
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape")
130-
131136
f_res = sh.fmt_idx("x", key)
132137
if isinstance(value, get_args(Scalar)):
133138
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
@@ -137,16 +142,6 @@ def test_setitem(shape, dtypes, data):
137142
assert res[key] == value, msg
138143
else:
139144
ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res)
140-
141-
axes_indices = []
142-
for i, side in zip(_key, shape):
143-
if isinstance(i, int):
144-
if i < 0:
145-
i += side
146-
axes_indices.append([i])
147-
else:
148-
indices = range(side)[i]
149-
axes_indices.append(indices)
150145
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
151146
for idx in unaffected_indices:
152147
ph.assert_0d_equals(

0 commit comments

Comments
 (0)