Skip to content

Commit ae0bf1f

Browse files
committed
Fix xp_array_less
1 parent 90c3aec commit ae0bf1f

File tree

2 files changed

+57
-24
lines changed

2 files changed

+57
-24
lines changed

src/array_api_extra/_lib/_testing.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
is_jax_namespace,
2323
is_numpy_namespace,
2424
is_pydata_sparse_namespace,
25+
is_torch_array,
2526
is_torch_namespace,
2627
to_device,
2728
)
@@ -62,18 +63,28 @@ def _check_ns_shape_dtype(
6263
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
6364
assert actual_xp == desired_xp, msg
6465

65-
if check_shape:
66-
actual_shape = actual.shape
67-
desired_shape = desired.shape
68-
if is_dask_namespace(desired_xp):
69-
# Dask uses nan instead of None for unknown shapes
70-
if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
71-
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
72-
if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
73-
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
66+
# Dask uses nan instead of None for unknown shapes
67+
actual_shape = cast(tuple[float, ...], actual.shape)
68+
desired_shape = cast(tuple[float, ...], desired.shape)
69+
assert None not in actual_shape # Requires explicit support
70+
assert None not in desired_shape
71+
if is_dask_namespace(desired_xp):
72+
if any(math.isnan(i) for i in actual_shape):
73+
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74+
if any(math.isnan(i) for i in desired_shape):
75+
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
7476

77+
if check_shape:
7578
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
7679
assert actual_shape == desired_shape, msg
80+
else:
81+
# Ignore shape, but check flattened size. This is normally done by
82+
# np.testing.assert_array_equal etc even when strict=False, but not for
83+
# non-materializable arrays.
84+
actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType]
85+
desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType]
86+
msg = f"sizes do not match: {actual_size} != f{desired_size}"
87+
assert actual_size == desired_size, msg
7788

7889
if check_dtype:
7990
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
@@ -90,6 +101,17 @@ def _check_ns_shape_dtype(
90101
return desired_xp
91102

92103

104+
def _is_materializable(x: Array) -> bool:
105+
"""
106+
Check if the array is materializable, e.g. `as_numpy_array` can be called on it
107+
and one can assume that `__dlpack__` will succeed (if implemented, and given a
108+
compatible device).
109+
"""
110+
# Important: here we assume that we're not tracing -
111+
# e.g. we're not inside `jax.jit`` nor `cupy.cuda.Stream.begin_capture`.
112+
return not is_torch_array(x) or x.device.type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
113+
114+
93115
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
94116
"""
95117
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
@@ -100,11 +122,7 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
100122
return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
101123

102124
if is_torch_namespace(xp):
103-
if array.device.type == "meta": # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
104-
# Can't materialize; generate dummy data instead
105-
array = xp.zeros_like(array, device="cpu")
106-
else:
107-
array = to_device(array, "cpu")
125+
array = to_device(array, "cpu")
108126
if is_array_api_strict_namespace(xp):
109127
cpu: Device = xp.Device("CPU_DEVICE")
110128
array = to_device(array, cpu)
@@ -150,6 +168,8 @@ def xp_assert_equal(
150168
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
151169
"""
152170
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
171+
if not _is_materializable(actual):
172+
return
153173
actual_np = as_numpy_array(actual, xp=xp)
154174
desired_np = as_numpy_array(desired, xp=xp)
155175
np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)
@@ -185,6 +205,8 @@ def xp_assert_less(
185205
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
186206
"""
187207
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
208+
if not _is_materializable(x):
209+
return
188210
x_np = as_numpy_array(x, xp=xp)
189211
y_np = as_numpy_array(y, xp=xp)
190212
np.testing.assert_array_less(x_np, y_np, err_msg=err_msg)
@@ -233,6 +255,8 @@ def xp_assert_close(
233255
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
234256
"""
235257
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
258+
if not _is_materializable(actual):
259+
return
236260

237261
if rtol is None:
238262
if xp.isdtype(actual.dtype, ("real floating", "complex floating")):

tests/test_testing.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,11 @@ def test_basic(self, xp: ModuleType):
3030
y = as_numpy_array(x, xp=xp)
3131
xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
3232

33-
def test_device(self, xp: ModuleType, library: Backend, device: Device):
33+
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
34+
def test_device(self, xp: ModuleType, device: Device):
3435
x = xp.asarray([1, 2, 3], device=device)
35-
actual = as_numpy_array(x, xp=xp)
36-
if library is Backend.TORCH:
37-
assert device.type == "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
38-
expect = np.asarray([0, 0, 0])
39-
else:
40-
expect = np.asarray([1, 2, 3])
41-
42-
xp_assert_equal(actual, expect) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
36+
y = as_numpy_array(x, xp=xp)
37+
xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
4338

4439

4540
class TestAssertEqualCloseLess:
@@ -92,7 +87,7 @@ def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
9287
func(a, b, check_shape=False)
9388
with pytest.raises(AssertionError, match="Mismatched elements"):
9489
func(a, c, check_shape=False)
95-
with pytest.raises(AssertionError, match=r"shapes \(1,\), \(2,\) mismatch"):
90+
with pytest.raises(AssertionError, match="sizes do not match"):
9691
func(a, d, check_shape=False)
9792

9893
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
@@ -181,6 +176,20 @@ def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
181176
with pytest.raises(AssertionError, match="Mismatched elements"):
182177
func(xp.asarray([4]), a)
183178

179+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
180+
def test_device(self, xp: ModuleType, device: Device, func: Callable[..., None]):
181+
a = xp.asarray([1] if func is xp_assert_less else [2], device=device)
182+
b = xp.asarray([2], device=device)
183+
c = xp.asarray([2, 2], device=device)
184+
185+
func(a, b)
186+
with pytest.raises(AssertionError, match="shapes do not match"):
187+
func(a, c)
188+
# This is normally performed by np.testing.assert_array_equal etc.
189+
# but in case of torch device='meta' we have to do it manually
190+
with pytest.raises(AssertionError, match="sizes do not match"):
191+
func(a, c, check_shape=False)
192+
184193

185194
def good_lazy(x: Array) -> Array:
186195
"""A function that behaves well in Dask and jax.jit"""

0 commit comments

Comments
 (0)