Skip to content

Support skipping dtypes by setting ARRAY_API_TESTS_SKIP_DTYPES #266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[flake8]
select = F
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This is the test suite for array libraries adopting the [Python Array API
standard](https://data-apis.org/array-api/latest).

Keeping full coverage of the spec is an on-going priority as the Array API evolves.
Keeping full coverage of the spec is an on-going priority as the Array API evolves.
Feedback and contributions are welcome!

## Quickstart
Expand Down Expand Up @@ -285,6 +285,19 @@ values should result in more rigorous runs. For example, `--max-examples
10_000` may find bugs where default runs don't but will take much longer to
run.

#### Skipping Dtypes

The test suite will automatically skip testing of inessential dtypes if they
are not present on the array module namespace, but dtypes can also be skipped
manually by setting the environment variable `ARRAY_API_TESTS_SKIP_DTYPES` to
a comma separated list of dtypes to skip. For example

```
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 pytest array_api_tests/
```

Note that skipping certain essential dtypes such as `bool` and the default
floating-point dtype is not supported.

## Contributing

Expand Down
36 changes: 20 additions & 16 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re
from collections import defaultdict
from collections.abc import Mapping
Expand Down Expand Up @@ -104,9 +105,18 @@ def __repr__(self):
numeric_names = real_names + complex_names
dtype_names = ("bool",) + numeric_names

_skip_dtypes = os.getenv("ARRAY_API_TESTS_SKIP_DTYPES", '')
_skip_dtypes = _skip_dtypes.split(',')
skip_dtypes = []
for dtype in _skip_dtypes:
if dtype and dtype not in dtype_names:
raise ValueError(f"Invalid dtype name in ARRAY_API_TESTS_SKIP_DTYPES: {dtype}")
skip_dtypes.append(dtype)

_name_to_dtype = {}
for name in dtype_names:
if name in skip_dtypes:
continue
try:
dtype = getattr(xp, name)
except AttributeError:
Expand Down Expand Up @@ -184,9 +194,9 @@ def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
dtype_value_pairs = []
for name, value in mapping.items():
assert isinstance(name, str) and name in dtype_names # sanity check
try:
dtype = getattr(xp, name)
except AttributeError:
if name in _name_to_dtype:
dtype = _name_to_dtype[name]
else:
continue
dtype_value_pairs.append((dtype, value))
return EqualityMapping(dtype_value_pairs)
Expand Down Expand Up @@ -313,9 +323,9 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
else:
default_complex = None
if dtype_nbits[default_int] == 32:
default_uint = getattr(xp, "uint32", None)
default_uint = _name_to_dtype.get("uint32")
else:
default_uint = getattr(xp, "uint64", None)
default_uint = _name_to_dtype.get("uint64")

_promotion_table: Dict[Tuple[str, str], str] = {
("bool", "bool"): "bool",
Expand Down Expand Up @@ -366,18 +376,12 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
_promotion_table.update({(d2, d1): res for (d1, d2), res in _promotion_table.items()})
_promotion_table_pairs: List[Tuple[Tuple[DataType, DataType], DataType]] = []
for (in_name1, in_name2), res_name in _promotion_table.items():
try:
in_dtype1 = getattr(xp, in_name1)
except AttributeError:
continue
try:
in_dtype2 = getattr(xp, in_name2)
except AttributeError:
continue
try:
res_dtype = getattr(xp, res_name)
except AttributeError:
if in_name1 not in _name_to_dtype or in_name2 not in _name_to_dtype or res_name not in _name_to_dtype:
continue
in_dtype1 = _name_to_dtype[in_name1]
in_dtype2 = _name_to_dtype[in_name2]
res_dtype = _name_to_dtype[res_name]

_promotion_table_pairs.append(((in_dtype1, in_dtype2), res_dtype))
promotion_table = EqualityMapping(_promotion_table_pairs)

Expand Down
34 changes: 24 additions & 10 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,24 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
return OnewayBroadcastableShapes(input_shape, result_shape)


# Use these instead of xps.scalar_dtypes, etc. because it skips dtypes from
# ARRAY_API_TESTS_SKIP_DTYPES
all_dtypes = sampled_from(_sorted_dtypes)
int_dtypes = sampled_from(dh.int_dtypes)
uint_dtypes = sampled_from(dh.uint_dtypes)
real_dtypes = sampled_from(dh.real_dtypes)
# Warning: The hypothesis "floating_dtypes" is what we call
# "real_floating_dtypes"
floating_dtypes = sampled_from(dh.all_float_dtypes)
real_floating_dtypes = sampled_from(dh.real_float_dtypes)
numeric_dtypes = sampled_from(dh.numeric_dtypes)
# Note: this always returns complex dtypes, even if api_version < 2022.12
complex_dtypes = sampled_from(dh.complex_dtypes)

def all_floating_dtypes() -> SearchStrategy[DataType]:
strat = xps.floating_dtypes()
strat = floating_dtypes
if api_version >= "2022.12":
strat |= xps.complex_dtypes()
strat |= complex_dtypes
return strat


Expand Down Expand Up @@ -236,7 +250,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):

@composite
def finite_matrices(draw, shape=matrix_shapes()):
return draw(arrays(dtype=xps.floating_dtypes(),
return draw(arrays(dtype=floating_dtypes,
shape=shape,
elements=dict(allow_nan=False,
allow_infinity=False)))
Expand All @@ -245,7 +259,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
# Should we set a max_value here?
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
rtols = one_of(floats(**_rtol_float_kw),
arrays(dtype=xps.floating_dtypes(),
arrays(dtype=real_floating_dtypes,
shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),
elements=_rtol_float_kw))

Expand Down Expand Up @@ -280,9 +294,9 @@ def mutually_broadcastable_shapes(

two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(2)

# Note: This should become hermitian_matrices when complex dtypes are added
# TODO: Add support for complex Hermitian matrices
@composite
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10.):
def symmetric_matrices(draw, dtypes=real_floating_dtypes, finite=True, bound=10.):
shape = draw(square_matrix_shapes)
dtype = draw(dtypes)
if not isinstance(finite, bool):
Expand All @@ -297,7 +311,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10
return H

@composite
def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
def positive_definite_matrices(draw, dtypes=floating_dtypes):
# For now just generate stacks of identity matrices
# TODO: Generate arbitrary positive definite matrices, for instance, by
# using something like
Expand All @@ -310,7 +324,7 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
return broadcast_to(eye(n, dtype=dtype), shape)

@composite
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()):
def invertible_matrices(draw, dtypes=floating_dtypes, stack_shapes=shapes()):
# For now, just generate stacks of diagonal matrices.
stack_shape = draw(stack_shapes)
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(stack_shape), 1)),)
Expand Down Expand Up @@ -344,7 +358,7 @@ def two_broadcastable_shapes(draw):
sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)

numeric_arrays = arrays(
dtype=shared(xps.floating_dtypes(), key='dtypes'),
dtype=shared(floating_dtypes, key='dtypes'),
shape=shared(xps.array_shapes(), key='shapes'),
)

Expand Down Expand Up @@ -388,7 +402,7 @@ def python_integer_indices(draw, sizes):
def integer_indices(draw, sizes):
# Return either a Python integer or a 0-D array with some integer dtype
idx = draw(python_integer_indices(sizes))
dtype = draw(xps.integer_dtypes() | xps.unsigned_integer_dtypes())
dtype = draw(int_dtypes | uint_dtypes)
m, M = dh.dtype_ranges[dtype]
if m <= idx <= M:
return draw(one_of(just(idx),
Expand Down
28 changes: 28 additions & 0 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,34 @@ def assert_dtype(
assert out_dtype == expected, msg


def assert_float_to_complex_dtype(
func_name: str, *, in_dtype: DataType, out_dtype: DataType
):
if in_dtype == xp.float32:
expected = xp.complex64
else:
assert in_dtype == xp.float64 # sanity check
expected = xp.complex128
assert_dtype(
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
)


def assert_complex_to_float_dtype(
func_name: str, *, in_dtype: DataType, out_dtype: DataType, repr_name: str = "out.dtype"
):
if in_dtype == xp.complex64:
expected = xp.float32
elif in_dtype == xp.complex128:
expected = xp.float64
else:
assert in_dtype in (xp.float32, xp.float64) # sanity check
expected = in_dtype
assert_dtype(
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected, repr_name=repr_name
)


def assert_kw_dtype(
func_name: str,
*,
Expand Down
23 changes: 9 additions & 14 deletions array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from . import xp as _xp
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape


Expand Down Expand Up @@ -75,7 +74,7 @@ def get_indexed_axes_and_out_shape(
return tuple(axes_indices), tuple(out_shape)


@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data())
@given(shape=hh.shapes(), dtype=hh.all_dtypes, data=st.data())
def test_getitem(shape, dtype, data):
zero_sided = any(side == 0 for side in shape)
if zero_sided:
Expand Down Expand Up @@ -157,7 +156,7 @@ def test_setitem(shape, dtypes, data):
@pytest.mark.data_dependent_shapes
@given(hh.shapes(), st.data())
def test_getitem_masking(shape, data):
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
mask_shapes = st.one_of(
st.sampled_from([x.shape, ()]),
st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
Expand Down Expand Up @@ -202,7 +201,7 @@ def test_getitem_masking(shape, data):
@pytest.mark.unvectorized
@given(hh.shapes(), st.data())
def test_setitem_masking(shape, data):
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
key = data.draw(hh.arrays(dtype=xp.bool, shape=shape), label="key")
value = data.draw(
hh.from_dtype(x.dtype) | hh.arrays(dtype=x.dtype, shape=()), label="value"
Expand Down Expand Up @@ -252,18 +251,14 @@ def make_scalar_casting_param(


@pytest.mark.parametrize(
"method_name, dtype_name, stype",
[make_scalar_casting_param("__bool__", "bool", bool)]
+ [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_names]
+ [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_names]
+ [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_names],
"method_name, dtype, stype",
[make_scalar_casting_param("__bool__", xp.bool, bool)]
+ [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_dtypes]
+ [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_dtypes]
+ [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_dtypes],
)
@given(data=st.data())
def test_scalar_casting(method_name, dtype_name, stype, data):
try:
dtype = getattr(_xp, dtype_name)
except AttributeError as e:
pytest.skip(str(e))
def test_scalar_casting(method_name, dtype, stype, data):
x = data.draw(hh.arrays(dtype, shape=()), label="x")
method = getattr(x, method_name)
out = method()
Expand Down
Loading
Loading