13
13
from . import shape_helpers as sh
14
14
from . import xps
15
15
from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
16
- from .typing import DataType , Param , Scalar , ScalarType , Shape
16
+ from .typing import DataType , Index , Param , Scalar , ScalarType , Shape
17
17
18
18
pytestmark = pytest .mark .ci
19
19
@@ -28,6 +28,24 @@ def scalar_objects(
28
28
)
29
29
30
30
31
+ def normalise_key (key : Index , shape : Shape ):
32
+ """
33
+ Normalise an indexing key.
34
+
35
+ * If a non-tuple index, wrap as a tuple.
36
+ * Represent ellipsis as equivalent slices.
37
+ """
38
+ _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
39
+ if Ellipsis in _key :
40
+ nonexpanding_key = tuple (i for i in _key if i is not None )
41
+ start_a = nonexpanding_key .index (Ellipsis )
42
+ stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
43
+ slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
44
+ start_pos = _key .index (Ellipsis )
45
+ _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
46
+ return _key
47
+
48
+
31
49
@given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
32
50
def test_getitem (shape , dtype , data ):
33
51
zero_sided = any (side == 0 for side in shape )
@@ -42,14 +60,7 @@ def test_getitem(shape, dtype, data):
42
60
out = x [key ]
43
61
44
62
ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
45
- _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
46
- if Ellipsis in _key :
47
- nonexpanding_key = tuple (i for i in _key if i is not None )
48
- start_a = nonexpanding_key .index (Ellipsis )
49
- stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
50
- slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
51
- start_pos = _key .index (Ellipsis )
52
- _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
63
+ _key = normalise_key (key , shape )
53
64
axes_indices = []
54
65
out_shape = []
55
66
a = 0
@@ -97,14 +108,7 @@ def test_setitem(shape, dtypes, data):
97
108
x = xp .asarray (obj , dtype = dtypes .result_dtype )
98
109
note (f"{ x = } " )
99
110
key = data .draw (xps .indices (shape = shape ), label = "key" )
100
- _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
101
- if Ellipsis in _key :
102
- nonexpanding_key = tuple (i for i in _key if i is not None )
103
- start_a = nonexpanding_key .index (Ellipsis )
104
- stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
105
- slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
106
- start_pos = _key .index (Ellipsis )
107
- _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
111
+ _key = normalise_key (key , shape )
108
112
out_shape = []
109
113
110
114
for i , side in zip (_key , shape ):
0 commit comments