@@ -55,11 +55,13 @@ def test_getitem(shape, dtype, data):
55
55
if i is None :
56
56
out_shape .append (1 )
57
57
else :
58
+ side = shape [a ]
58
59
if isinstance (i , int ):
60
+ if i < 0 :
61
+ i += side
59
62
axes_indices .append ([i ])
60
63
else :
61
64
assert isinstance (i , slice ) # sanity check
62
- side = shape [a ]
63
65
indices = range (side )[i ]
64
66
axes_indices .append (indices )
65
67
out_shape .append (len (indices ))
@@ -102,9 +104,9 @@ def test_setitem(shape, dtypes, data):
102
104
start_pos = _key .index (Ellipsis )
103
105
_key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
104
106
out_shape = []
105
- for a , i in enumerate (_key ):
107
+
108
+ for i , side in zip (_key , shape ):
106
109
if isinstance (i , slice ):
107
- side = shape [a ]
108
110
indices = range (side )[i ]
109
111
out_shape .append (len (indices ))
110
112
out_shape = tuple (out_shape )
@@ -119,7 +121,8 @@ def test_setitem(shape, dtypes, data):
119
121
120
122
ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
121
123
ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
122
- f_res = f"res[{ sh .fmt_idx ('x' , key )} ]"
124
+
125
+ f_res = sh .fmt_idx ("x" , key )
123
126
if isinstance (value , get_args (Scalar )):
124
127
msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
125
128
if math .isnan (value ):
@@ -128,14 +131,21 @@ def test_setitem(shape, dtypes, data):
128
131
assert res [key ] == value , msg
129
132
else :
130
133
ph .assert_array_elements ("__setitem__" , res [key ], value , out_repr = f_res )
131
- if all (isinstance (i , int ) for i in _key ): # TODO: normalise slices and ellipsis
132
- _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
133
- unaffected_indices = list (sh .ndindex (res .shape ))
134
- unaffected_indices .remove (_key )
135
- for idx in unaffected_indices :
136
- ph .assert_0d_equals (
137
- "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
138
- )
134
+
135
+ axes_indices = []
136
+ for i , side in zip (_key , shape ):
137
+ if isinstance (i , int ):
138
+ if i < 0 :
139
+ i += side
140
+ axes_indices .append ([i ])
141
+ else :
142
+ indices = range (side )[i ]
143
+ axes_indices .append (indices )
144
+ unaffected_indices = set (sh .ndindex (res .shape )) - set (product (* axes_indices ))
145
+ for idx in unaffected_indices :
146
+ ph .assert_0d_equals (
147
+ "__setitem__" , f"old { f_res } " , x [idx ], f"modified { f_res } " , res [idx ]
148
+ )
139
149
140
150
141
151
@pytest .mark .data_dependent_shapes
0 commit comments