Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 4a0943a

Browse files
committed Dec 10, 2024 ·
WIP at() method
1 parent 343ebf4 commit 4a0943a

File tree

8 files changed

+481
-8
lines changed

8 files changed

+481
-8
lines changed
 

‎docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

‎pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ ignore = [
226226
"ISC001", # Conflicts with formatter
227227
"N802", # Function name should be lowercase
228228
"N806", # Variable in function should be lowercase
229+
"PD008", # pandas-use-of-dot-at
229230
]
230231

231232
[tool.ruff.lint.per-file-ignores]

‎src/array_api_extra/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.3.3.dev0"
615

716
__all__ = [
817
"__version__",
18+
"at",
919
"atleast_nd",
1020
"cov",
1121
"create_diagonal",

‎src/array_api_extra/_funcs.py

Lines changed: 286 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from __future__ import annotations
22

3+
import operator
34
import warnings
5+
from collections.abc import Callable
6+
from typing import Any
47

58
from ._lib import _utils
6-
from ._lib._compat import array_namespace
9+
from ._lib._compat import (
10+
array_namespace,
11+
is_array_api_obj,
12+
is_dask_array,
13+
is_writeable_array,
14+
)
715
from ._lib._typing import Array, ModuleType
816

917
__all__ = [
18+
"at",
1019
"atleast_nd",
1120
"cov",
1221
"create_diagonal",
@@ -545,3 +554,279 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
545554
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
546555
)
547556
return xp.sin(y) / y
557+
558+
559+
_undef = object()
560+
561+
562+
class at: # noqa: N801
563+
"""
564+
Update operations for read-only arrays.
565+
566+
This implements ``jax.numpy.ndarray.at`` for all backends.
567+
568+
Parameters
569+
----------
570+
x : array
571+
Input array.
572+
idx : index, optional
573+
You may use two alternate syntaxes::
574+
575+
at(x, idx).set(value) # or get(), add(), etc.
576+
at(x)[idx].set(value)
577+
578+
copy : bool, optional
579+
True (default)
580+
Ensure that the inputs are not modified.
581+
False
582+
Ensure that the update operation writes back to the input.
583+
Raise ValueError if a copy cannot be avoided.
584+
None
585+
The array parameter *may* be modified in place if it is possible and
586+
beneficial for performance.
587+
You should not reuse it after calling this function.
588+
xp : array_namespace, optional
589+
The standard-compatible namespace for `x`. Default: infer
590+
591+
**kwargs:
592+
If the backend supports an `at` method, any additional keyword
593+
arguments are passed to it verbatim; e.g. this allows passing
594+
``indices_are_sorted=True`` to JAX.
595+
596+
Returns
597+
-------
598+
Updated input array.
599+
600+
Examples
601+
--------
602+
Given either of these equivalent expressions::
603+
604+
x = at(x)[1].add(2, copy=None)
605+
x = at(x, 1).add(2, copy=None)
606+
607+
If x is a JAX array, they are the same as::
608+
609+
x = x.at[1].add(2)
610+
611+
If x is a read-only numpy array, they are the same as::
612+
613+
x = x.copy()
614+
x[1] += 2
615+
616+
Otherwise, they are the same as::
617+
618+
x[1] += 2
619+
620+
Warning
621+
-------
622+
When you use copy=None, you should always immediately overwrite
623+
the parameter array::
624+
625+
x = at(x, 0).set(2, copy=None)
626+
627+
The anti-pattern below must be avoided, as it will result in different behaviour
628+
on read-only versus writeable arrays::
629+
630+
x = xp.asarray([0, 0, 0])
631+
y = at(x, 0).set(2, copy=None)
632+
z = at(x, 1).set(3, copy=None)
633+
634+
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
635+
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
636+
637+
Warning
638+
-------
639+
The array API standard does not support integer array indices.
640+
The behaviour of update methods when the index is an array of integers
641+
is undefined; this is particularly true when the index contains multiple
642+
occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``.
643+
644+
Note
645+
----
646+
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
647+
648+
See Also
649+
--------
650+
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
651+
"""
652+
653+
x: Array
654+
idx: Any
655+
__slots__ = ("idx", "x")
656+
657+
def __init__(self, x: Array, idx: Any = _undef, /):
658+
self.x = x
659+
self.idx = idx
660+
661+
def __getitem__(self, idx: Any) -> Any:
662+
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
663+
which looks prettier than ``at(x, slice(start, stop, step))``
664+
and feels more intuitive coming from the JAX documentation.
665+
"""
666+
if self.idx is not _undef:
667+
msg = "Index has already been set"
668+
raise ValueError(msg)
669+
self.idx = idx
670+
return self
671+
672+
def _common(
673+
self,
674+
at_op: str,
675+
y: Array = _undef,
676+
/,
677+
copy: bool | None = True,
678+
xp: ModuleType | None = None,
679+
_is_update: bool = True,
680+
**kwargs: Any,
681+
) -> tuple[Any, None] | tuple[None, Array]:
682+
"""Perform common prepocessing.
683+
684+
Returns
685+
-------
686+
If the operation can be resolved by at[], (return value, None)
687+
Otherwise, (None, preprocessed x)
688+
"""
689+
if self.idx is _undef:
690+
msg = (
691+
"Index has not been set.\n"
692+
"Usage: either\n"
693+
" at(x, idx).set(value)\n"
694+
"or\n"
695+
" at(x)[idx].set(value)\n"
696+
"(same for all other methods)."
697+
)
698+
raise TypeError(msg)
699+
700+
x = self.x
701+
702+
if copy is True:
703+
writeable = None
704+
elif copy is False:
705+
writeable = is_writeable_array(x)
706+
if not writeable:
707+
msg = "Cannot modify parameter in place"
708+
raise ValueError(msg)
709+
elif copy is None:
710+
writeable = is_writeable_array(x)
711+
copy = _is_update and not writeable
712+
else:
713+
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable]
714+
raise ValueError(msg)
715+
716+
if copy:
717+
try:
718+
at_ = x.at
719+
except AttributeError:
720+
# Emulate at[] behaviour for non-JAX arrays
721+
# with a copy followed by an update
722+
if xp is None:
723+
xp = array_namespace(x)
724+
# Create writeable copy of read-only numpy array
725+
x = xp.asarray(x, copy=True)
726+
if writeable is False:
727+
# A copy of a read-only numpy array is writeable
728+
writeable = None
729+
else:
730+
# Use JAX's at[] or other library that with the same duck-type API
731+
args = (y,) if y is not _undef else ()
732+
return getattr(at_[self.idx], at_op)(*args, **kwargs), None
733+
734+
if _is_update:
735+
if writeable is None:
736+
writeable = is_writeable_array(x)
737+
if not writeable:
738+
# sparse crashes here
739+
msg = f"Array {x} has no `at` method and is read-only"
740+
raise ValueError(msg)
741+
742+
return None, x
743+
744+
def get(self, **kwargs: Any) -> Any:
745+
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
746+
that the output is either a copy or a view; it also allows passing
747+
keyword arguments to the backend.
748+
"""
749+
if kwargs.get("copy") is False:
750+
if is_array_api_obj(self.idx):
751+
# Boolean index. Note that the array API spec
752+
# https://data-apis.org/array-api/latest/API_specification/indexing.html
753+
# does not allow for list, tuple, and tuples of slices plus one or more
754+
# one-dimensional array indices, although many backends support them.
755+
# So this check will encounter a lot of false negatives in real life,
756+
# which can be caught by testing the user code vs. array-api-strict.
757+
msg = "get() with an array index always returns a copy"
758+
raise ValueError(msg)
759+
if is_dask_array(self.x):
760+
msg = "get() on Dask arrays always returns a copy"
761+
raise ValueError(msg)
762+
763+
res, x = self._common("get", _is_update=False, **kwargs)
764+
if res is not None:
765+
return res
766+
assert x is not None
767+
return x[self.idx]
768+
769+
def set(self, y: Array, /, **kwargs: Any) -> Array:
770+
"""Apply ``x[idx] = y`` and return the update array"""
771+
res, x = self._common("set", y, **kwargs)
772+
if res is not None:
773+
return res
774+
assert x is not None
775+
x[self.idx] = y
776+
return x
777+
778+
def _iop(
779+
self,
780+
at_op: str,
781+
elwise_op: Callable[[Array, Array], Array],
782+
y: Array,
783+
/,
784+
**kwargs: Any,
785+
) -> Array:
786+
"""x[idx] += y or equivalent in-place operation on a subset of x
787+
788+
which is the same as saying
789+
x[idx] = x[idx] + y
790+
Note that this is not the same as
791+
operator.iadd(x[idx], y)
792+
Consider for example when x is a numpy array and idx is a fancy index, which
793+
triggers a deep copy on __getitem__.
794+
"""
795+
res, x = self._common(at_op, y, **kwargs)
796+
if res is not None:
797+
return res
798+
assert x is not None
799+
x[self.idx] = elwise_op(x[self.idx], y)
800+
return x
801+
802+
def add(self, y: Array, /, **kwargs: Any) -> Array:
803+
"""Apply ``x[idx] += y`` and return the updated array"""
804+
return self._iop("add", operator.add, y, **kwargs)
805+
806+
def subtract(self, y: Array, /, **kwargs: Any) -> Array:
807+
"""Apply ``x[idx] -= y`` and return the updated array"""
808+
return self._iop("subtract", operator.sub, y, **kwargs)
809+
810+
def multiply(self, y: Array, /, **kwargs: Any) -> Array:
811+
"""Apply ``x[idx] *= y`` and return the updated array"""
812+
return self._iop("multiply", operator.mul, y, **kwargs)
813+
814+
def divide(self, y: Array, /, **kwargs: Any) -> Array:
815+
"""Apply ``x[idx] /= y`` and return the updated array"""
816+
return self._iop("divide", operator.truediv, y, **kwargs)
817+
818+
def power(self, y: Array, /, **kwargs: Any) -> Array:
819+
"""Apply ``x[idx] **= y`` and return the updated array"""
820+
return self._iop("power", operator.pow, y, **kwargs)
821+
822+
def min(self, y: Array, /, **kwargs: Any) -> Array:
823+
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
824+
xp = array_namespace(self.x)
825+
y = xp.asarray(y)
826+
return self._iop("min", xp.minimum, y, **kwargs)
827+
828+
def max(self, y: Array, /, **kwargs: Any) -> Array:
829+
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
830+
xp = array_namespace(self.x)
831+
y = xp.asarray(y)
832+
return self._iop("max", xp.maximum, y, **kwargs)

‎src/array_api_extra/_lib/_compat.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,26 @@
33
from __future__ import annotations
44

55
try:
6-
from ..._array_api_compat_vendor import (
7-
array_namespace,
8-
device,
6+
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
7+
array_namespace, # pyright: ignore[reportUnknownVariableType]
8+
device, # pyright: ignore[reportUnknownVariableType]
9+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
10+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
11+
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
912
)
1013
except ImportError:
1114
from array_api_compat import (
1215
array_namespace,
1316
device,
17+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
18+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
19+
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
1420
)
1521

16-
__all__ = [
22+
__all__ = (
1723
"array_namespace",
1824
"device",
19-
]
25+
"is_array_api_obj",
26+
"is_dask_array",
27+
"is_writeable_array",
28+
)

‎src/array_api_extra/_lib/_compat.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14+
def is_array_api_obj(x: object, /) -> bool: ...
15+
def is_dask_array(x: object, /) -> bool: ...
16+
def is_writeable_array(x: object, /) -> bool: ...

‎tests/test_at.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager, suppress
4+
from importlib import import_module
5+
from typing import TYPE_CHECKING
6+
7+
import numpy as np
8+
import pytest
9+
from array_api_compat import (
10+
array_namespace,
11+
is_dask_array,
12+
is_pydata_sparse_array,
13+
is_writeable_array,
14+
)
15+
16+
from array_api_extra import at
17+
18+
if TYPE_CHECKING:
19+
from array_api_extra._lib._typing import Array
20+
21+
all_libraries = (
22+
"array_api_strict",
23+
"numpy",
24+
"numpy_readonly",
25+
"cupy",
26+
"torch",
27+
"dask.array",
28+
"sparse",
29+
"jax.numpy",
30+
)
31+
32+
33+
@pytest.fixture(params=all_libraries)
def array(request):
    """Build ``[10.0, 20.0, 30.0]`` with the backend named by the fixture param.

    The special name ``numpy_readonly`` produces a NumPy array whose
    ``writeable`` flag is cleared. Backends that cannot be imported cause
    the test to be skipped.
    """
    name = request.param

    if name == "numpy_readonly":
        readonly = np.asarray([10.0, 20.0, 30.0])
        readonly.flags.writeable = False
        return readonly

    try:
        backend = import_module(name)
    except ImportError:
        pytest.skip(f"{name} is not installed")
    return backend.asarray([10.0, 20.0, 30.0])
46+
47+
48+
def assert_array_equal(a: Array, b: Array) -> None:
    """Assert that every element of `a` equals the matching element of `b`.

    `b` is converted to an array of `a`'s namespace before comparing.
    A Dask result is computed before it is asserted.
    """
    namespace = array_namespace(a)
    all_equal = namespace.all(a == namespace.asarray(b))
    outcome = all_equal.compute() if is_dask_array(a) else all_equal
    assert outcome
55+
56+
57+
@contextmanager
def assert_copy(array, copy: bool | None):
    """Check that the wrapped code did (or did not) leave `array` untouched.

    When ``copy=False`` is requested on an array that is not writeable, the
    wrapped code is expected to raise TypeError or ValueError instead.
    Otherwise `array` is snapshotted before the wrapped code runs and compared
    afterwards: it must be unchanged when a copy was expected and changed
    when it was not. With ``copy=None`` a copy is expected exactly when the
    array is not writeable.
    """
    # dask arrays are writeable, but writing to them will hot-swap the
    # dask graph inside the collection so that anything that references
    # the original graph, i.e. the input collection, won't be mutated.
    must_raise = copy is False and not is_writeable_array(array)
    if must_raise:
        with pytest.raises((TypeError, ValueError)):
            yield
        return

    namespace = array_namespace(array)
    snapshot = namespace.asarray(array, copy=True)
    yield

    if copy is None:
        expect_copy = not is_writeable_array(array)
    else:
        expect_copy = copy
    unchanged = namespace.all(array == snapshot)
    assert_array_equal(unchanged, expect_copy)
73+
74+
75+
@pytest.mark.parametrize("copy", [True, False, None])
@pytest.mark.parametrize(
    ("op", "arg", "expect"),
    [
        ("set", 40.0, [10.0, 40.0, 40.0]),
        ("add", 40.0, [10.0, 60.0, 70.0]),
        ("subtract", 100.0, [10.0, -80.0, -70.0]),
        ("multiply", 2.0, [10.0, 40.0, 60.0]),
        ("divide", 2.0, [10.0, 10.0, 15.0]),
        ("power", 2.0, [10.0, 400.0, 900.0]),
        ("min", 25.0, [10.0, 20.0, 25.0]),
        ("max", 25.0, [10.0, 25.0, 30.0]),
    ],
)
def test_update_ops(array, copy, op, arg, expect):
    """Each update method applied to ``array[1:]`` yields the expected values,
    and the input is copied or mutated according to `copy`.
    """
    if is_pydata_sparse_array(array):
        pytest.skip("at() does not support updates on sparse arrays")

    with assert_copy(array, copy):
        updater = at(array, slice(1, None))
        method = getattr(updater, op)
        y = method(arg, copy=copy)
        assert isinstance(y, type(array))
        assert_array_equal(y, expect)
97+
98+
99+
@pytest.mark.parametrize("copy", [True, False, None])
def test_get(array, copy):
    """``get()`` returns ``array[:2]`` as a copy or a view according to `copy`."""
    # dask is mutable, but __getitem__ never returns a view
    on_dask = is_dask_array(array)
    if on_dask and copy is False:
        with pytest.raises(ValueError, match="always returns a copy"):
            at(array, slice(2)).get(copy=False)
        return
    expect_copy = True if on_dask else copy

    with assert_copy(array, expect_copy):
        y = at(array, slice(2)).get(copy=copy)
        assert isinstance(y, type(array))
        assert_array_equal(y, [10.0, 20.0])
        # Let assert_copy test that y is a view or copy
        with suppress(TypeError, ValueError):
            y[:] = 40
118+
119+
120+
def test_get_bool_indices(array):
    """get() with a boolean array index always returns a copy"""
    # sparse violates the array API as it doesn't support
    # a boolean index that is another sparse array.
    # dask with dask index has NaN size, which complicates testing.
    use_numpy_index = is_pydata_sparse_array(array) or is_dask_array(array)
    xp = np if use_numpy_index else array_namespace(array)
    idx = xp.asarray([True, False, True])

    with pytest.raises(ValueError, match="copy"):
        at(array, idx).get(copy=False)

    selected = at(array, idx).get()
    assert_array_equal(selected, [10.0, 30.0])

    with assert_copy(array, True):
        y = at(array, idx).get(copy=True)
        assert_array_equal(y, [10.0, 30.0])
        # Let assert_copy test that y is a view or copy
        with suppress(TypeError, ValueError):
            y[:] = 40
142+
143+
144+
def test_copy_invalid():
    """A `copy` value other than True, False or None is rejected with ValueError."""
    values = np.asarray([1, 2, 3])
    updater = at(values, 0)
    with pytest.raises(ValueError, match="copy"):
        updater.set(4, copy="invalid")
148+
149+
150+
def test_xp():
    """An explicit ``xp=`` namespace is accepted by update methods."""
    source = np.asarray([1, 2, 3])
    updated = at(source, 0).set(4, xp=np)
    assert_array_equal(updated, [4, 2, 3])

‎vendor_tests/test_vendor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,21 @@
55

66

77
def test_vendor_compat():
    """The vendored array_api_compat exposes the helpers used by array_api_extra."""
    from ._array_api_compat_vendor import (  # type: ignore[attr-defined]
        array_namespace,
        device,
        is_array_api_obj,
        is_dask_array,
        is_writeable_array,
    )

    sample = xp.asarray([1, 2, 3])
    assert array_namespace(sample) is xp  # type: ignore[no-untyped-call]
    # device() only needs to run without raising
    device(sample)
    assert is_array_api_obj(sample)
    assert not is_array_api_obj(123)
    assert not is_dask_array(sample)
    assert is_writeable_array(sample)
1223

1324

1425
def test_vendor_extra():

0 commit comments

Comments
 (0)
Please sign in to comment.