Skip to content

Commit cdd1c8d

Browse files
authored
Merge pull request #205 from crusaderky/jax
ENH: Test for read-only arrays
2 parents 1c7b1ba + c28c40d commit cdd1c8d

File tree

4 files changed

+42
-4
lines changed

4 files changed

+42
-4
lines changed

array_api_compat/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
NumPy Array API compatibility library
33
4-
This is a small wrapper around NumPy and CuPy that is compatible with the
5-
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
6-
https://numpy.org/neps/nep-0047-array-api-standard.html.
4+
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are
5+
compatible with the Array API standard https://data-apis.org/array-api/latest/.
6+
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
77
88
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with

array_api_compat/common/_helpers.py

+19
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
787787
return x
788788
return x.to_device(device, stream=stream)
789789

790+
790791
def size(x):
791792
"""
792793
Return the total number of elements of x.
@@ -801,6 +802,23 @@ def size(x):
801802
return None
802803
return math.prod(x.shape)
803804

805+
806+
def is_writeable_array(x) -> bool:
807+
"""
808+
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
809+
810+
Warning
811+
-------
812+
As there is no standard way to check if an array is writeable without actually
813+
writing to it, this function blindly returns True for all unknown array types.
814+
"""
815+
if is_numpy_array(x):
816+
return x.flags.writeable
817+
if is_jax_array(x) or is_pydata_sparse_array(x):
818+
return False
819+
return True
820+
821+
804822
__all__ = [
805823
"array_namespace",
806824
"device",
@@ -821,6 +839,7 @@ def size(x):
821839
"is_ndonnx_namespace",
822840
"is_pydata_sparse_array",
823841
"is_pydata_sparse_namespace",
842+
"is_writeable_array",
824843
"size",
825844
"to_device",
826845
]

docs/helper-functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ yet.
5151
.. autofunction:: is_jax_array
5252
.. autofunction:: is_pydata_sparse_array
5353
.. autofunction:: is_ndonnx_array
54+
.. autofunction:: is_writeable_array
5455
.. autofunction:: is_numpy_namespace
5556
.. autofunction:: is_cupy_namespace
5657
.. autofunction:: is_torch_namespace

tests/test_common.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
66
)
77

8-
from array_api_compat import is_array_api_obj, device, to_device
8+
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
99

1010
from ._helpers import import_, wrapped_libraries, all_libraries
1111

@@ -74,6 +74,24 @@ def test_xp_is_array_generics(library):
7474
assert matches in ([library], ["numpy"])
7575

7676

77+
@pytest.mark.parametrize("library", all_libraries)
78+
def test_is_writeable_array(library):
79+
lib = import_(library)
80+
x = lib.asarray([1, 2, 3])
81+
if is_writeable_array(x):
82+
x[1] = 4
83+
else:
84+
with pytest.raises((TypeError, ValueError)):
85+
x[1] = 4
86+
87+
88+
def test_is_writeable_array_numpy():
89+
x = np.asarray([1, 2, 3])
90+
assert is_writeable_array(x)
91+
x.flags.writeable = False
92+
assert not is_writeable_array(x)
93+
94+
7795
@pytest.mark.parametrize("library", all_libraries)
7896
def test_device(library):
7997
xp = import_(library, wrapper=True)

0 commit comments

Comments
 (0)