Skip to content

Commit 8a5103b

Browse files
committed
test_asarray_arrays improvements
* Test all possible dtype kwargs * Fix erroneous nan equals * Clean up copy testing
1 parent 52e835e commit 8a5103b

File tree

1 file changed

+39
-16
lines changed

1 file changed

+39
-16
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
1414
from . import xps
15+
from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
1516
from .typing import DataType, Scalar
1617

1718
pytestmark = pytest.mark.ci
@@ -245,11 +246,25 @@ def test_asarray_scalars(shape, data):
245246
ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw)
246247

247248

248-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), st.data())
249-
def test_asarray_arrays(x, data):
250-
# TODO: test other valid dtypes
249+
def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
250+
if math.isnan(s1):
251+
return math.isnan(s2)
252+
else:
253+
return s1 == s2
254+
255+
256+
@given(
257+
shape=hh.shapes(),
258+
dtypes=oneway_promotable_dtypes(dh.all_dtypes),
259+
data=st.data(),
260+
)
261+
def test_asarray_arrays(shape, dtypes, data):
262+
x = data.draw(xps.arrays(dtype=dtypes.input_dtype, shape=shape), label="x")
263+
dtypes_strat = st.just(dtypes.input_dtype)
264+
if dtypes.input_dtype == dtypes.result_dtype:
265+
dtypes_strat |= st.none()
251266
kw = data.draw(
252-
hh.kwargs(dtype=st.none() | st.just(x.dtype), copy=st.none() | st.booleans()),
267+
hh.kwargs(dtype=dtypes_strat, copy=st.none() | st.booleans()),
253268
label="kw",
254269
)
255270

@@ -261,27 +276,35 @@ def test_asarray_arrays(x, data):
261276
else:
262277
ph.assert_kw_dtype("asarray", dtype, out.dtype)
263278
ph.assert_shape("asarray", out.shape, x.shape)
264-
if dtype is None or dtype == x.dtype:
265-
ph.assert_array_elements("asarray", out, x, **kw)
266-
else:
267-
pass # TODO
279+
ph.assert_array_elements("asarray", out, x, **kw)
268280
copy = kw.get("copy", None)
269281
if copy is not None:
282+
stype = dh.get_scalar_type(x.dtype)
270283
idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx")
271-
_dtype = x.dtype if dtype is None else dtype
272-
old_value = x[idx]
284+
old_value = stype(x[idx])
285+
scalar_strat = xps.from_dtype(dtypes.input_dtype).filter(
286+
lambda n: not scalar_eq(n, old_value)
287+
)
273288
value = data.draw(
274-
xps.arrays(dtype=_dtype, shape=()).filter(lambda y: y != old_value),
289+
scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)),
275290
label="mutating value",
276291
)
277292
x[idx] = value
278293
note(f"mutated {x=}")
294+
# sanity check
295+
ph.assert_scalar_equals(
296+
"__setitem__", stype, idx, stype(x[idx]), value, repr_name="x"
297+
)
298+
new_out_value = stype(out[idx])
299+
f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}"
279300
if copy:
280-
assert not xp.all(
281-
out == x
282-
), f"xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
283-
elif copy is False:
284-
pass # TODO
301+
assert scalar_eq(
302+
new_out_value, old_value
303+
), f"{f_out}, but should be {old_value} even after x was mutated"
304+
else:
305+
assert scalar_eq(
306+
new_out_value, value
307+
), f"{f_out}, but should be {value} after x was mutated"
285308

286309

287310
@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))

0 commit comments

Comments
 (0)