Skip to content

ENH: support PyTorch device='meta' #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
) -> Array:
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""

if not capabilities(xp)["boolean indexing"]:
if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
# jax.jit does not support assignment by boolean mask
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)

Expand Down Expand Up @@ -716,7 +716,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
# 2. backend has unique_counts and it returns a None-sized array;
# e.g. Dask, ndonnx
# 3. backend does not have unique_counts; e.g. wrapped JAX
if capabilities(xp)["data-dependent shapes"]:
if capabilities(xp, device=_compat.device(x))["data-dependent shapes"]:
# xp has unique_counts; O(n) complexity
_, counts = xp.unique_counts(x)
n = _compat.size(counts)
Expand Down
6 changes: 5 additions & 1 deletion src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]

if is_torch_namespace(xp):
array = to_device(array, "cpu")
if array.device.type == "meta": # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
# Can't materialize; generate dummy data instead
array = xp.zeros_like(array, device="cpu")
else:
array = to_device(array, "cpu")
if is_array_api_strict_namespace(xp):
cpu: Device = xp.Device("CPU_DEVICE")
array = to_device(array, cpu)
Expand Down
14 changes: 12 additions & 2 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
is_jax_namespace,
is_numpy_array,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._typing import Array
from ._typing import Array, Device

if TYPE_CHECKING: # pragma: no cover
# TODO import from typing (requires Python >=3.12 and >=3.13)
Expand Down Expand Up @@ -300,7 +301,7 @@ def meta_namespace(
return array_namespace(*metas)


def capabilities(xp: ModuleType) -> dict[str, int]:
def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, int]:
"""
Return patched ``xp.__array_namespace_info__().capabilities()``.

Expand All @@ -311,6 +312,8 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
----------
xp : array_namespace
The standard-compatible namespace.
device : Device, optional
The device to use.

Returns
-------
Expand All @@ -326,6 +329,13 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
# Fixed in jax >=0.6.0
out = out.copy()
out["boolean indexing"] = False
if is_torch_namespace(xp):
# FIXME https://github.com/data-apis/array-api/issues/945
device = xp.get_default_device() if device is None else xp.device(device)
if cast(Any, device).type == "meta": # type: ignore[explicit-any]
out = out.copy()
out["boolean indexing"] = False
out["data-dependent shapes"] = False
return out


Expand Down
8 changes: 5 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def device(
Where possible, return a device that is not the default one.
"""
if library == Backend.ARRAY_API_STRICT:
d = xp.Device("device1")
assert get_device(xp.empty(0)) != d
return d
return xp.Device("device1")
if library == Backend.TORCH:
return xp.device("meta")
if library == Backend.TORCH_GPU:
return xp.device("cpu")
return get_device(xp.empty(0))
6 changes: 3 additions & 3 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,6 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
res = isclose(a, b, equal_nan=equal_nan)
assert get_device(res) == device
xp_assert_equal(
isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])
)


class TestKron:
Expand Down Expand Up @@ -996,6 +993,9 @@ def test_all_python_scalars(self, assume_unique: bool):
_ = setdiff1d(0, 0, assume_unique=assume_unique)

@assume_unique
@pytest.mark.skip_xp_backend(
Backend.TORCH, reason="device='meta' does not support unknown shapes"
)
def test_device(self, xp: ModuleType, device: Device, assume_unique: bool):
x1 = xp.asarray([3, 8, 20], device=device)
x2 = xp.asarray([2, 3, 4], device=device)
Expand Down
30 changes: 25 additions & 5 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,31 @@ def test_xp(self, xp: ModuleType):
assert meta_namespace(*args, xp=xp) in (xp, np_compat)


def test_capabilities(xp: ModuleType):
expect = {"boolean indexing", "data-dependent shapes"}
if xp.__array_api_version__ >= "2024.12":
expect.add("max dimensions")
assert capabilities(xp).keys() == expect
class TestCapabilities:
def test_basic(self, xp: ModuleType):
expect = {"boolean indexing", "data-dependent shapes"}
if xp.__array_api_version__ >= "2024.12":
expect.add("max dimensions")
assert capabilities(xp).keys() == expect

def test_device(self, xp: ModuleType, library: Backend, device: Device):
expect_keys = {"boolean indexing", "data-dependent shapes"}
if xp.__array_api_version__ >= "2024.12":
expect_keys.add("max dimensions")
assert capabilities(xp, device=device).keys() == expect_keys

if library.like(Backend.TORCH):
# The output of capabilities is device-specific.

# Test that device=None gets the current default device.
expect = capabilities(xp, device=device)
with xp.device(device):
actual = capabilities(xp)
assert actual == expect

# Test that we're accepting anything that is accepted by the
# device= parameter in other functions
actual = capabilities(xp, device=device.type) # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue]


class Wrapper(Generic[T]):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
Backend.ARRAY_API_STRICT, reason="device->host copy"
),
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
pytest.mark.skip_xp_backend(
Backend.TORCH, reason="materialize 'meta' device"
),
pytest.mark.skip_xp_backend(
Backend.TORCH_GPU, reason="device->host copy"
),
Expand Down
20 changes: 16 additions & 4 deletions tests/test_testing.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to add an explicit test for xp_assert_equal etc. but it's blocked by #301

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now merged

Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,22 @@
)


def test_as_numpy_array(xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3], device=device)
y = as_numpy_array(x, xp=xp)
assert isinstance(y, np.ndarray)
class TestAsNumPyArray:
def test_basic(self, xp: ModuleType):
x = xp.asarray([1, 2, 3])
y = as_numpy_array(x, xp=xp)
xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

def test_device(self, xp: ModuleType, library: Backend, device: Device):
x = xp.asarray([1, 2, 3], device=device)
actual = as_numpy_array(x, xp=xp)
if library is Backend.TORCH:
assert device.type == "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
expect = np.asarray([0, 0, 0])
else:
expect = np.asarray([1, 2, 3])

xp_assert_equal(actual, expect) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
Expand Down