Skip to content

MAINT: Check essential data functions #380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from itertools import count
from typing import Iterator, NamedTuple, Union

import pytest
from hypothesis import assume, given, note
from hypothesis import strategies as st

Expand Down Expand Up @@ -76,6 +77,7 @@ def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float]
)


@pytest.mark.has_setup_funcs
@given(dtype=st.none() | hh.real_dtypes, data=st.data())
def test_arange(dtype, data):
if dtype is None or dh.is_float_dtype(dtype):
Expand Down Expand Up @@ -194,6 +196,7 @@ def test_arange(dtype, data):
), f"out[0]={out[0]}, but should be {_start} {f_func}"


@pytest.mark.has_setup_funcs
@given(shape=hh.shapes(min_side=1), data=st.data())
def test_asarray_scalars(shape, data):
kw = data.draw(
Expand Down Expand Up @@ -257,6 +260,7 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
return s1 == s2


@pytest.mark.has_setup_funcs
@given(
shape=hh.shapes(),
dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes),
Expand Down Expand Up @@ -424,6 +428,7 @@ def test_full(shape, fill_value, kw):
ph.assert_fill("full", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value))


@pytest.mark.has_setup_funcs
@given(kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), data=st.data())
def test_full_like(kw, data):
dtype = kw.get("dtype", None) or data.draw(hh.all_dtypes, label="dtype")
Expand All @@ -442,6 +447,7 @@ def test_full_like(kw, data):
finite_kw = {"allow_nan": False, "allow_infinity": False}


@pytest.mark.has_setup_funcs
@given(
num=hh.sizes,
dtype=st.none() | hh.real_floating_dtypes,
Expand Down Expand Up @@ -492,6 +498,7 @@ def test_linspace(num, dtype, endpoint, data):
ph.assert_array_elements("linspace", out=out, expected=expected)


@pytest.mark.has_setup_funcs
@given(dtype=hh.numeric_dtypes, data=st.data())
def test_meshgrid(dtype, data):
# The number and size of generated arrays is arbitrarily limited to prevent
Expand Down
1 change: 1 addition & 0 deletions array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def test_repeat(x, kw, data):

reshape_shape = st.shared(hh.shapes(), key="reshape_shape")

@pytest.mark.has_setup_funcs
@pytest.mark.unvectorized
@given(
x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
Expand Down
29 changes: 29 additions & 0 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ def test_acosh(x):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes))
@given(data=st.data())
def test_add(ctx, data):
Expand Down Expand Up @@ -854,6 +855,7 @@ def test_atanh(x):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes)
)
Expand All @@ -873,6 +875,7 @@ def test_bitwise_and(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes)
)
Expand All @@ -895,6 +898,7 @@ def test_bitwise_left_shift(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes)
)
Expand All @@ -913,6 +917,7 @@ def test_bitwise_invert(ctx, data):
unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}")


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes)
)
Expand All @@ -932,6 +937,7 @@ def test_bitwise_or(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes)
)
Expand All @@ -953,6 +959,7 @@ def test_bitwise_right_shift(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes)
)
Expand Down Expand Up @@ -981,6 +988,7 @@ def test_ceil(x):


@pytest.mark.min_version("2023.12")
@pytest.mark.has_setup_funcs
@given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data())
def test_clip(x, data):
# Ensure that if both min and max are arrays that all three of x, min, max
Expand Down Expand Up @@ -1145,6 +1153,7 @@ def test_cosh(x):
unary_assert_against_refimpl("cosh", x, out, refimpl)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes))
@given(data=st.data())
def test_divide(ctx, data):
Expand All @@ -1168,6 +1177,7 @@ def test_divide(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes))
@given(data=st.data())
def test_equal(ctx, data):
Expand Down Expand Up @@ -1242,6 +1252,7 @@ def refimpl(z):
unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes))
@given(data=st.data())
def test_floor_divide(ctx, data):
Expand All @@ -1261,6 +1272,7 @@ def test_floor_divide(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes))
@given(data=st.data())
def test_greater(ctx, data):
Expand All @@ -1281,6 +1293,7 @@ def test_greater(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes))
@given(data=st.data())
def test_greater_equal(ctx, data):
Expand Down Expand Up @@ -1352,6 +1365,7 @@ def test_isnan(x):
unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes))
@given(data=st.data())
def test_less(ctx, data):
Expand All @@ -1372,6 +1386,7 @@ def test_less(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes))
@given(data=st.data())
def test_less_equal(ctx, data):
Expand Down Expand Up @@ -1463,6 +1478,7 @@ def logaddexp_refimpl(l: float, r: float) -> float:


@pytest.mark.min_version("2023.12")
@pytest.mark.has_setup_funcs
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_logaddexp(x1, x2):
out = xp.logaddexp(x1, x2)
Expand All @@ -1476,6 +1492,7 @@ def test_logaddexp(x1, x2):
)


@pytest.mark.has_setup_funcs
@given(hh.arrays(dtype=xp.bool, shape=hh.shapes()))
def test_logical_not(x):
out = xp.logical_not(x)
Expand All @@ -1486,6 +1503,7 @@ def test_logical_not(x):
)


@pytest.mark.has_setup_funcs
@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_and(x1, x2):
out = xp.logical_and(x1, x2)
Expand All @@ -1500,6 +1518,7 @@ def test_logical_and(x1, x2):
)


@pytest.mark.has_setup_funcs
@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_or(x1, x2):
out = xp.logical_or(x1, x2)
Expand All @@ -1514,6 +1533,7 @@ def test_logical_or(x1, x2):
)


@pytest.mark.has_setup_funcs
@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_xor(x1, x2):
out = xp.logical_xor(x1, x2)
Expand Down Expand Up @@ -1546,6 +1566,7 @@ def test_minimum(x1, x2):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
@given(data=st.data())
def test_multiply(ctx, data):
Expand Down Expand Up @@ -1577,6 +1598,7 @@ def test_negative(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes))
@given(data=st.data())
def test_not_equal(ctx, data):
Expand All @@ -1598,6 +1620,7 @@ def test_not_equal(ctx, data):


@pytest.mark.min_version("2024.12")
@pytest.mark.has_setup_funcs
@given(
shapes=hh.two_mutually_broadcastable_shapes,
dtype=hh.real_floating_dtypes,
Expand All @@ -1617,6 +1640,8 @@ def test_nextafter(shapes, dtype, data):
out=out
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes))
@given(data=st.data())
def test_positive(ctx, data):
Expand All @@ -1629,6 +1654,7 @@ def test_positive(ctx, data):
ph.assert_array_elements(ctx.func_name, out=out, expected=x)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))
@given(data=st.data())
def test_pow(ctx, data):
Expand Down Expand Up @@ -1676,6 +1702,7 @@ def test_reciprocal(x):


@pytest.mark.skip(reason="flaky")
@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes))
@given(data=st.data())
def test_remainder(ctx, data):
Expand Down Expand Up @@ -1770,6 +1797,7 @@ def test_sqrt(x):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes))
@given(data=st.data())
def test_subtract(ctx, data):
Expand Down Expand Up @@ -1923,6 +1951,7 @@ def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):


@pytest.mark.min_version("2024.12")
@pytest.mark.has_setup_funcs
@pytest.mark.unvectorized
@given(
x1x2=hh.array_and_py_scalar([xp.int32]),
Expand Down
1 change: 1 addition & 0 deletions array_api_tests/test_utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_any(x, data):

@pytest.mark.unvectorized
@pytest.mark.min_version("2024.12")
@pytest.mark.has_setup_funcs
@given(
x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
data=st.data(),
Expand Down
15 changes: 15 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ def pytest_configure(config):
"markers",
"unvectorized: asserts against values via element-wise iteration (not performative!)",
)
config.addinivalue_line(
"markers",
"has_setup_funcs: run when essential draw data setup functions used "
"by Hypothesis are available in the namespace",
)
# Hypothesis
deadline = None if config.getoption("--hypothesis-disable-deadline") else 800
settings.register_profile(
Expand Down Expand Up @@ -202,6 +207,9 @@ def pytest_collection_modifyitems(config, items):
# ------------------------------------------------------

xfail_mark = get_xfail_mark()

essential_funcs = ["asarray", "isnan", "reshape", "zeros"]
HAS_ESSENTIAL_FUNCS = all(hasattr(xp, func_name) for func_name in essential_funcs)

for item in items:
markers = list(item.iter_markers())
Expand Down Expand Up @@ -245,6 +253,13 @@ def pytest_collection_modifyitems(config, items):
reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
)
)
# skip if namespace doesn't support essential draw data setup functions
if any(m.name == "has_setup_funcs" for m in markers) and not HAS_ESSENTIAL_FUNCS:
item.add_marker(
mark.skip(reason="At least one of the essential data setup "
"functions is not present in the namespace: "
f"{essential_funcs}")
)
# reduce max generated Hypothesis example for unvectorized tests
if any(m.name == "unvectorized" for m in markers):
# TODO: limit generated examples when settings already applied
Expand Down