Skip to content

Commit e5dd419

Browse files
authored
Merge pull request #231 from crusaderky/test_size
ENH: size() to return None on dask instead of nan
2 parents beac55b + d947529 commit e5dd419

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

array_api_compat/common/_helpers.py

+8 -3
Original file line number | Diff line number | Diff line change
@@ -788,19 +788,24 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788788
return x.to_device(device, stream=stream)
789789

790790

791-
def size(x):
791+
def size(x: Array) -> int | None:
792792
"""
793793
Return the total number of elements of x.
794794
795795
This is equivalent to `x.size` according to the `standard
796796
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
797+
797798
This helper is included because PyTorch defines `size` in an
798799
:external+torch:meth:`incompatible way <torch.Tensor.size>`.
799-
800+
It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
801+
the standard requires None.
800802
"""
803+
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
801804
if None in x.shape:
802805
return None
803-
return math.prod(x.shape)
806+
out = math.prod(x.shape)
807+
# dask.array.Array.shape can contain NaN
808+
return None if math.isnan(out) else out
804809

805810

806811
def is_writeable_array(x) -> bool:

tests/test_common.py

+25 -2
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,9 @@
55
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
66
)
77

8-
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
9-
8+
from array_api_compat import (
9+
device, is_array_api_obj, is_writeable_array, size, to_device
10+
)
1011
from ._helpers import import_, wrapped_libraries, all_libraries
1112

1213
import pytest
@@ -92,6 +93,28 @@ def test_is_writeable_array_numpy():
9293
assert not is_writeable_array(x)
9394

9495

96+
@pytest.mark.parametrize("library", all_libraries)
97+
def test_size(library):
98+
xp = import_(library)
99+
x = xp.asarray([1, 2, 3])
100+
assert size(x) == 3
101+
102+
103+
@pytest.mark.parametrize("library", all_libraries)
104+
def test_size_none(library):
105+
if library == "sparse":
106+
pytest.skip("No arange(); no indexing by sparse arrays")
107+
108+
xp = import_(library)
109+
x = xp.arange(10)
110+
x = x[x < 5]
111+
112+
# dask.array now has shape=(nan, ) and size=nan
113+
# ndonnx now has shape=(None, ) and size=None
114+
# Eager libraries have shape=(5, ) and size=5
115+
assert size(x) in (None, 5)
116+
117+
95118
@pytest.mark.parametrize("library", all_libraries)
96119
def test_device(library):
97120
xp = import_(library, wrapper=True)

0 commit comments

Comments
 (0)