1
1
import math
2
2
from itertools import product
3
- from typing import List , Union , get_args
3
+ from typing import List , Sequence , Tuple , Union , get_args
4
4
5
5
import pytest
6
6
from hypothesis import assume , given , note
@@ -28,7 +28,7 @@ def scalar_objects(
28
28
)
29
29
30
30
31
- def normalise_key (key : Index , shape : Shape ):
31
+ def normalise_key (key : Index , shape : Shape ) -> Tuple [ Union [ int , slice ], ...] :
32
32
"""
33
33
Normalise an indexing key.
34
34
@@ -46,40 +46,52 @@ def normalise_key(key: Index, shape: Shape):
46
46
return _key
47
47
48
48
49
- @given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
50
- def test_getitem (shape , dtype , data ):
51
- zero_sided = any (side == 0 for side in shape )
52
- if zero_sided :
53
- x = xp .zeros (shape , dtype = dtype )
54
- else :
55
- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
56
- x = xp .asarray (obj , dtype = dtype )
57
- note (f"{ x = } " )
58
- key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
59
-
60
- out = x [key ]
49
+ def get_indexed_axes_and_out_shape (
50
+ key : Tuple [Union [int , slice , None ], ...], shape : Shape
51
+ ) -> Tuple [Tuple [Sequence [int ], ...], Shape ]:
52
+ """
53
+ From the (normalised) key and input shape, calculates:
61
54
62
- ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
63
- _key = normalise_key (key , shape )
55
+ * indexed_axes: For each dimension, the axes which the key indexes.
56
+ * out_shape: The resulting shape of indexing an array (of the input shape)
57
+ with the key.
58
+ """
64
59
axes_indices = []
65
60
out_shape = []
66
61
a = 0
67
- for i in _key :
62
+ for i in key :
68
63
if i is None :
69
64
out_shape .append (1 )
70
65
else :
71
66
side = shape [a ]
72
67
if isinstance (i , int ):
73
68
if i < 0 :
74
69
i += side
75
- axes_indices .append ([ i ] )
70
+ axes_indices .append (( i ,) )
76
71
else :
77
- assert isinstance (i , slice ) # sanity check
78
72
indices = range (side )[i ]
79
73
axes_indices .append (indices )
80
74
out_shape .append (len (indices ))
81
75
a += 1
82
- out_shape = tuple (out_shape )
76
+ return tuple (axes_indices ), tuple (out_shape )
77
+
78
+
79
+ @given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
80
+ def test_getitem (shape , dtype , data ):
81
+ zero_sided = any (side == 0 for side in shape )
82
+ if zero_sided :
83
+ x = xp .zeros (shape , dtype = dtype )
84
+ else :
85
+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
86
+ x = xp .asarray (obj , dtype = dtype )
87
+ note (f"{ x = } " )
88
+ key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
89
+
90
+ out = x [key ]
91
+
92
+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
93
+ _key = normalise_key (key , shape )
94
+ axes_indices , out_shape = get_indexed_axes_and_out_shape (_key , shape )
83
95
ph .assert_shape ("__getitem__" , out .shape , out_shape )
84
96
out_zero_sided = any (side == 0 for side in out_shape )
85
97
if not zero_sided and not out_zero_sided :
@@ -109,13 +121,7 @@ def test_setitem(shape, dtypes, data):
109
121
note (f"{ x = } " )
110
122
key = data .draw (xps .indices (shape = shape ), label = "key" )
111
123
_key = normalise_key (key , shape )
112
- out_shape = []
113
-
114
- for i , side in zip (_key , shape ):
115
- if isinstance (i , slice ):
116
- indices = range (side )[i ]
117
- out_shape .append (len (indices ))
118
- out_shape = tuple (out_shape )
124
+ axes_indices , out_shape = get_indexed_axes_and_out_shape (_key , shape )
119
125
value_strat = xps .arrays (dtype = dtypes .result_dtype , shape = out_shape )
120
126
if out_shape == ():
121
127
# We can pass scalars if we're only indexing one element
@@ -127,7 +133,6 @@ def test_setitem(shape, dtypes, data):
127
133
128
134
ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
129
135
ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
130
-
131
136
f_res = sh .fmt_idx ("x" , key )
132
137
if isinstance (value , get_args (Scalar )):
133
138
msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
@@ -137,16 +142,6 @@ def test_setitem(shape, dtypes, data):
137
142
assert res [key ] == value , msg
138
143
else :
139
144
ph .assert_array_elements ("__setitem__" , res [key ], value , out_repr = f_res )
140
-
141
- axes_indices = []
142
- for i , side in zip (_key , shape ):
143
- if isinstance (i , int ):
144
- if i < 0 :
145
- i += side
146
- axes_indices .append ([i ])
147
- else :
148
- indices = range (side )[i ]
149
- axes_indices .append (indices )
150
145
unaffected_indices = set (sh .ndindex (res .shape )) - set (product (* axes_indices ))
151
146
for idx in unaffected_indices :
152
147
ph .assert_0d_equals (
0 commit comments