Skip to content

Commit 131037d

Browse files
committed
normalise_key() util for indexing tests
1 parent a4dd075 commit 131037d

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

array_api_tests/test_array_object.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from . import shape_helpers as sh
1414
from . import xps
1515
from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
16-
from .typing import DataType, Param, Scalar, ScalarType, Shape
16+
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape
1717

1818
pytestmark = pytest.mark.ci
1919

@@ -28,6 +28,24 @@ def scalar_objects(
2828
)
2929

3030

31+
def normalise_key(key: Index, shape: Shape):
32+
"""
33+
Normalise an indexing key.
34+
35+
* If a non-tuple index, wrap as a tuple.
36+
* Represent ellipsis as equivalent slices.
37+
"""
38+
_key = tuple(key) if isinstance(key, tuple) else (key,)
39+
if Ellipsis in _key:
40+
nonexpanding_key = tuple(i for i in _key if i is not None)
41+
start_a = nonexpanding_key.index(Ellipsis)
42+
stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1))
43+
slices = tuple(slice(None) for _ in range(start_a, stop_a))
44+
start_pos = _key.index(Ellipsis)
45+
_key = _key[:start_pos] + slices + _key[start_pos + 1 :]
46+
return _key
47+
48+
3149
@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data())
3250
def test_getitem(shape, dtype, data):
3351
zero_sided = any(side == 0 for side in shape)
@@ -42,14 +60,7 @@ def test_getitem(shape, dtype, data):
4260
out = x[key]
4361

4462
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
45-
_key = tuple(key) if isinstance(key, tuple) else (key,)
46-
if Ellipsis in _key:
47-
nonexpanding_key = tuple(i for i in _key if i is not None)
48-
start_a = nonexpanding_key.index(Ellipsis)
49-
stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1))
50-
slices = tuple(slice(None) for _ in range(start_a, stop_a))
51-
start_pos = _key.index(Ellipsis)
52-
_key = _key[:start_pos] + slices + _key[start_pos + 1 :]
63+
_key = normalise_key(key, shape)
5364
axes_indices = []
5465
out_shape = []
5566
a = 0
@@ -97,14 +108,7 @@ def test_setitem(shape, dtypes, data):
97108
x = xp.asarray(obj, dtype=dtypes.result_dtype)
98109
note(f"{x=}")
99110
key = data.draw(xps.indices(shape=shape), label="key")
100-
_key = tuple(key) if isinstance(key, tuple) else (key,)
101-
if Ellipsis in _key:
102-
nonexpanding_key = tuple(i for i in _key if i is not None)
103-
start_a = nonexpanding_key.index(Ellipsis)
104-
stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1))
105-
slices = tuple(slice(None) for _ in range(start_a, stop_a))
106-
start_pos = _key.index(Ellipsis)
107-
_key = _key[:start_pos] + slices + _key[start_pos + 1 :]
111+
_key = normalise_key(key, shape)
108112
out_shape = []
109113

110114
for i, side in zip(_key, shape):

array_api_tests/typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@
1616
ScalarType = Union[Type[bool], Type[int], Type[float]]
1717
Array = Any
1818
Shape = Tuple[int, ...]
19-
AtomicIndex = Union[int, "ellipsis", slice] # noqa
19+
AtomicIndex = Union[int, "ellipsis", slice, None] # noqa
2020
Index = Union[AtomicIndex, Tuple[AtomicIndex, ...]]
2121
Param = Tuple

0 commit comments

Comments
 (0)