Skip to content

Commit cc9f403

Browse files
authored
TST/BUG: run all tests on all backends; fix backend-specific bugs (#88)
1 parent 1708482 commit cc9f403

15 files changed

+529
-201
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ repos:
4444
- repo: https://github.com/astral-sh/ruff-pre-commit
4545
rev: "v0.8.2"
4646
hooks:
47+
- id: ruff-format
4748
- id: ruff
4849
args: ["--fix", "--show-fixes"]
49-
- id: ruff-format
5050

5151
- repo: https://github.com/codespell-project/codespell
5252
rev: "v2.3.0"

pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ xfail_strict = true
180180
filterwarnings = ["error"]
181181
log_cli_level = "INFO"
182182
testpaths = ["tests"]
183-
183+
markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific backend"]
184184

185185
# Coverage
186186

@@ -315,6 +315,7 @@ checks = [
315315
exclude = [ # don't report on objects that match any of these regex
316316
'.*test_at.*',
317317
'.*test_funcs.*',
318+
'.*test_testing.*',
318319
'.*test_utils.*',
319320
'.*test_version.*',
320321
'.*test_vendor.*',

src/array_api_extra/_funcs.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ def create_diagonal(
214214
raise ValueError(err_msg)
215215
n = x.shape[0] + abs(offset)
216216
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
217-
i = offset if offset >= 0 else abs(offset) * n
218-
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
217+
218+
start = offset if offset >= 0 else abs(offset) * n
219+
stop = min(n * (n - offset), diag.shape[0])
220+
step = n + 1
221+
diag = at(diag)[start:stop:step].set(x)
222+
219223
return xp.reshape(diag, (n, n))
220224

221225

@@ -407,9 +411,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
407411
result = xp.multiply(a_arr, b_arr)
408412

409413
# Reshape back and return
410-
a_shape = xp.asarray(a_shape)
411-
b_shape = xp.asarray(b_shape)
412-
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))
414+
res_shape = tuple(a_s * b_s for a_s, b_s in zip(a_shape, b_shape, strict=True))
415+
return xp.reshape(result, res_shape)
413416

414417

415418
def setdiff1d(
@@ -632,8 +635,7 @@ def pad(
632635
dtype=x.dtype,
633636
device=_compat.device(x),
634637
)
635-
padded[tuple(slices)] = x
636-
return padded
638+
return at(padded, tuple(slices)).set(x)
637639

638640

639641
class _AtOp(Enum):

src/array_api_extra/_lib/_compat.py

+15
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,35 @@
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace,
88
device,
9+
is_cupy_namespace,
910
is_jax_array,
11+
is_jax_namespace,
12+
is_pydata_sparse_namespace,
13+
is_torch_namespace,
1014
is_writeable_array,
15+
size,
1116
)
1217
except ImportError:
1318
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1419
array_namespace,
1520
device,
21+
is_cupy_namespace,
1622
is_jax_array,
23+
is_jax_namespace,
24+
is_pydata_sparse_namespace,
25+
is_torch_namespace,
1726
is_writeable_array,
27+
size,
1828
)
1929

2030
__all__ = [
2131
"array_namespace",
2232
"device",
33+
"is_cupy_namespace",
2334
"is_jax_array",
35+
"is_jax_namespace",
36+
"is_pydata_sparse_namespace",
37+
"is_torch_namespace",
2438
"is_writeable_array",
39+
"size",
2540
]

src/array_api_extra/_lib/_compat.pyi

+5
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,10 @@ def array_namespace(
1818
use_compat: bool | None = None,
1919
) -> ArrayModule: ...
2020
def device(x: Array, /) -> Device: ...
21+
def is_cupy_namespace(x: object, /) -> bool: ...
2122
def is_jax_array(x: object, /) -> bool: ...
23+
def is_jax_namespace(x: object, /) -> bool: ...
24+
def is_pydata_sparse_namespace(x: object, /) -> bool: ...
25+
def is_torch_namespace(x: object, /) -> bool: ...
2226
def is_writeable_array(x: object, /) -> bool: ...
27+
def size(x: Array, /) -> int | None: ...

src/array_api_extra/_lib/_testing.py

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""
2+
Testing utilities.
3+
4+
Note that this is private API; don't expect it to be stable.
5+
"""
6+
7+
from ._compat import (
8+
array_namespace,
9+
is_cupy_namespace,
10+
is_pydata_sparse_namespace,
11+
is_torch_namespace,
12+
)
13+
from ._typing import Array, ModuleType
14+
15+
__all__ = ["xp_assert_close", "xp_assert_equal"]
16+
17+
18+
def _check_ns_shape_dtype(
19+
actual: Array, desired: Array
20+
) -> ModuleType: # numpydoc ignore=RT03
21+
"""
22+
Assert that namespace, shape and dtype of the two arrays match.
23+
24+
Parameters
25+
----------
26+
actual : Array
27+
The array produced by the tested function.
28+
desired : Array
29+
The expected array (typically hardcoded).
30+
31+
Returns
32+
-------
33+
Arrays namespace.
34+
"""
35+
actual_xp = array_namespace(actual) # Raises on scalars and lists
36+
desired_xp = array_namespace(desired)
37+
38+
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
39+
assert actual_xp == desired_xp, msg
40+
41+
msg = f"shapes do not match: {actual.shape} != f{desired.shape}"
42+
assert actual.shape == desired.shape, msg
43+
44+
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
45+
assert actual.dtype == desired.dtype, msg
46+
47+
return desired_xp
48+
49+
50+
def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
51+
"""
52+
Array-API compatible version of `np.testing.assert_array_equal`.
53+
54+
Parameters
55+
----------
56+
actual : Array
57+
The array produced by the tested function.
58+
desired : Array
59+
The expected array (typically hardcoded).
60+
err_msg : str, optional
61+
Error message to display on failure.
62+
"""
63+
xp = _check_ns_shape_dtype(actual, desired)
64+
65+
if is_cupy_namespace(xp):
66+
xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
67+
elif is_torch_namespace(xp):
68+
# PyTorch recommends using `rtol=0, atol=0` like this
69+
# to test for exact equality
70+
xp.testing.assert_close(
71+
actual,
72+
desired,
73+
rtol=0,
74+
atol=0,
75+
equal_nan=True,
76+
check_dtype=False,
77+
msg=err_msg or None,
78+
)
79+
else:
80+
import numpy as np # pylint: disable=import-outside-toplevel
81+
82+
if is_pydata_sparse_namespace(xp):
83+
actual = actual.todense()
84+
desired = desired.todense()
85+
86+
# JAX uses `np.testing`
87+
np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
88+
89+
90+
def xp_assert_close(
91+
actual: Array,
92+
desired: Array,
93+
*,
94+
rtol: float | None = None,
95+
atol: float = 0,
96+
err_msg: str = "",
97+
) -> None:
98+
"""
99+
Array-API compatible version of `np.testing.assert_allclose`.
100+
101+
Parameters
102+
----------
103+
actual : Array
104+
The array produced by the tested function.
105+
desired : Array
106+
The expected array (typically hardcoded).
107+
rtol : float, optional
108+
Relative tolerance. Default: dtype-dependent.
109+
atol : float, optional
110+
Absolute tolerance. Default: 0.
111+
err_msg : str, optional
112+
Error message to display on failure.
113+
"""
114+
xp = _check_ns_shape_dtype(actual, desired)
115+
116+
floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
117+
if rtol is None and floating:
118+
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
119+
# roughly half way between sqrt(eps) and the default for
120+
# `numpy.testing.assert_allclose`, 1e-7
121+
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
122+
elif rtol is None:
123+
rtol = 1e-7
124+
125+
if is_cupy_namespace(xp):
126+
xp.testing.assert_allclose(
127+
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
128+
)
129+
elif is_torch_namespace(xp):
130+
xp.testing.assert_close(
131+
actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None
132+
)
133+
else:
134+
import numpy as np # pylint: disable=import-outside-toplevel
135+
136+
if is_pydata_sparse_namespace(xp):
137+
actual = actual.to_dense()
138+
desired = desired.to_dense()
139+
140+
# JAX uses `np.testing`
141+
assert isinstance(rtol, float)
142+
np.testing.assert_allclose(
143+
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
144+
)

src/array_api_extra/_lib/_utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def in1d(
5454
order = xp.argsort(ar, stable=True)
5555
reverse_order = xp.argsort(order, stable=True)
5656
sar = xp.take(ar, order, axis=0)
57-
if sar.size >= 1:
57+
ar_size = _compat.size(sar)
58+
assert ar_size is not None, "xp.unique*() on lazy backends raises"
59+
if ar_size >= 1:
5860
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
5961
else:
6062
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])

tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Needed to import .conftest from the test modules."""

tests/conftest.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Pytest fixtures."""
2+
3+
from enum import Enum
4+
from typing import cast
5+
6+
import pytest
7+
8+
from array_api_extra._lib._compat import array_namespace
9+
from array_api_extra._lib._compat import device as get_device
10+
from array_api_extra._lib._typing import Device, ModuleType
11+
12+
13+
class Library(Enum):
14+
"""All array libraries explicitly tested by array-api-extra."""
15+
16+
ARRAY_API_STRICT = "array_api_strict"
17+
NUMPY = "numpy"
18+
NUMPY_READONLY = "numpy_readonly"
19+
CUPY = "cupy"
20+
TORCH = "torch"
21+
DASK_ARRAY = "dask.array"
22+
SPARSE = "sparse"
23+
JAX_NUMPY = "jax.numpy"
24+
25+
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
26+
"""Pretty-print parameterized test names."""
27+
return self.value
28+
29+
30+
@pytest.fixture(params=tuple(Library))
31+
def library(request: pytest.FixtureRequest) -> Library: # numpydoc ignore=PR01,RT03
32+
"""
33+
Parameterized fixture that iterates on all libraries.
34+
35+
Returns
36+
-------
37+
The current Library enum.
38+
"""
39+
elem = cast(Library, request.param)
40+
41+
for marker in request.node.iter_markers("skip_xp_backend"):
42+
skip_library = marker.kwargs.get("library") or marker.args[0] # type: ignore[no-untyped-usage]
43+
if not isinstance(skip_library, Library):
44+
msg = "argument of skip_xp_backend must be a Library enum"
45+
raise TypeError(msg)
46+
if skip_library == elem:
47+
reason = cast(str, marker.kwargs.get("reason", "skip_xp_backend"))
48+
pytest.skip(reason=reason)
49+
50+
return elem
51+
52+
53+
@pytest.fixture
54+
def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03
55+
"""
56+
Parameterized fixture that iterates on all libraries.
57+
58+
Returns
59+
-------
60+
The current array namespace.
61+
"""
62+
name = "numpy" if library == Library.NUMPY_READONLY else library.value
63+
xp = pytest.importorskip(name)
64+
if library == Library.JAX_NUMPY:
65+
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
66+
67+
jax.config.update("jax_enable_x64", True)
68+
69+
# Possibly wrap module with array_api_compat
70+
return array_namespace(xp.empty(0))
71+
72+
73+
@pytest.fixture
74+
def device(
75+
library: Library, xp: ModuleType
76+
) -> Device: # numpydoc ignore=PR01,RT01,RT03
77+
"""
78+
Return a valid device for the backend.
79+
80+
Where possible, return a device that is not the default one.
81+
"""
82+
if library == Library.ARRAY_API_STRICT:
83+
d = xp.Device("device1")
84+
assert get_device(xp.empty(0)) != d
85+
return d
86+
return get_device(xp.empty(0))

0 commit comments

Comments
 (0)