12
12
from . import pytest_helpers as ph
13
13
from . import shape_helpers as sh
14
14
from . import xps
15
+ from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
15
16
from .typing import DataType , Scalar
16
17
17
18
pytestmark = pytest .mark .ci
@@ -245,11 +246,25 @@ def test_asarray_scalars(shape, data):
245
246
ph .assert_scalar_equals ("asarray" , scalar_type , idx , v , v_expect , ** kw )
246
247
247
248
248
- @given (xps .arrays (dtype = xps .scalar_dtypes (), shape = hh .shapes ()), st .data ())
249
- def test_asarray_arrays (x , data ):
250
- # TODO: test other valid dtypes
249
+ def scalar_eq (s1 : Scalar , s2 : Scalar ) -> bool :
250
+ if math .isnan (s1 ):
251
+ return math .isnan (s2 )
252
+ else :
253
+ return s1 == s2
254
+
255
+
256
+ @given (
257
+ shape = hh .shapes (),
258
+ dtypes = oneway_promotable_dtypes (dh .all_dtypes ),
259
+ data = st .data (),
260
+ )
261
+ def test_asarray_arrays (shape , dtypes , data ):
262
+ x = data .draw (xps .arrays (dtype = dtypes .input_dtype , shape = shape ), label = "x" )
263
+ dtypes_strat = st .just (dtypes .input_dtype )
264
+ if dtypes .input_dtype == dtypes .result_dtype :
265
+ dtypes_strat |= st .none ()
251
266
kw = data .draw (
252
- hh .kwargs (dtype = st . none () | st . just ( x . dtype ) , copy = st .none () | st .booleans ()),
267
+ hh .kwargs (dtype = dtypes_strat , copy = st .none () | st .booleans ()),
253
268
label = "kw" ,
254
269
)
255
270
@@ -261,27 +276,35 @@ def test_asarray_arrays(x, data):
261
276
else :
262
277
ph .assert_kw_dtype ("asarray" , dtype , out .dtype )
263
278
ph .assert_shape ("asarray" , out .shape , x .shape )
264
- if dtype is None or dtype == x .dtype :
265
- ph .assert_array_elements ("asarray" , out , x , ** kw )
266
- else :
267
- pass # TODO
279
+ ph .assert_array_elements ("asarray" , out , x , ** kw )
268
280
copy = kw .get ("copy" , None )
269
281
if copy is not None :
282
+ stype = dh .get_scalar_type (x .dtype )
270
283
idx = data .draw (xps .indices (x .shape , max_dims = 0 ), label = "mutating idx" )
271
- _dtype = x .dtype if dtype is None else dtype
272
- old_value = x [idx ]
284
+ old_value = stype (x [idx ])
285
+ scalar_strat = xps .from_dtype (dtypes .input_dtype ).filter (
286
+ lambda n : not scalar_eq (n , old_value )
287
+ )
273
288
value = data .draw (
274
- xps . arrays ( dtype = _dtype , shape = ()). filter (lambda y : y != old_value ),
289
+ scalar_strat | scalar_strat . map (lambda n : xp . asarray ( n , dtype = x . dtype ) ),
275
290
label = "mutating value" ,
276
291
)
277
292
x [idx ] = value
278
293
note (f"mutated { x = } " )
294
+ # sanity check
295
+ ph .assert_scalar_equals (
296
+ "__setitem__" , stype , idx , stype (x [idx ]), value , repr_name = "x"
297
+ )
298
+ new_out_value = stype (out [idx ])
299
+ f_out = f"{ sh .fmt_idx ('out' , idx )} ={ new_out_value } "
279
300
if copy :
280
- assert not xp .all (
281
- out == x
282
- ), f"xp.all(out == x)=True, but should be False after x was mutated\n { out = } "
283
- elif copy is False :
284
- pass # TODO
301
+ assert scalar_eq (
302
+ new_out_value , old_value
303
+ ), f"{ f_out } , but should be { old_value } even after x was mutated"
304
+ else :
305
+ assert scalar_eq (
306
+ new_out_value , value
307
+ ), f"{ f_out } , but should be { value } after x was mutated"
285
308
286
309
287
310
@given (hh .shapes (), hh .kwargs (dtype = st .none () | hh .shared_dtypes ))
0 commit comments