Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 712dd9e

Browse files
committedApr 25, 2025·
ENH: jax_autojit
1 parent 4425d14 commit 712dd9e

File tree

8 files changed

+628
-104
lines changed

8 files changed

+628
-104
lines changed
 

‎src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 259 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,24 @@
22

33
from __future__ import annotations
44

5+
import io
56
import math
6-
from collections.abc import Generator, Iterable
7+
import pickle
8+
import types
9+
from collections.abc import Callable, Generator, Iterable
10+
from functools import wraps
711
from types import ModuleType
8-
from typing import TYPE_CHECKING, cast
12+
from typing import (
13+
TYPE_CHECKING,
14+
Any,
15+
ClassVar,
16+
Generic,
17+
Literal,
18+
ParamSpec,
19+
TypeAlias,
20+
TypeVar,
21+
cast,
22+
)
923

1024
from . import _compat
1125
from ._compat import (
@@ -19,8 +33,16 @@
1933
from ._typing import Array
2034

2135
if TYPE_CHECKING: # pragma: no cover
22-
# TODO import from typing (requires Python >=3.13)
23-
from typing_extensions import TypeIs
36+
# TODO import from typing (requires Python >=3.12 and >=3.13)
37+
from typing_extensions import TypeIs, override
38+
else:
39+
40+
def override(func):
41+
return func
42+
43+
44+
P = ParamSpec("P")
45+
T = TypeVar("T")
2446

2547

2648
__all__ = [
@@ -29,8 +51,11 @@
2951
"eager_shape",
3052
"in1d",
3153
"is_python_scalar",
54+
"jax_autojit",
3255
"mean",
3356
"meta_namespace",
57+
"pickle_flatten",
58+
"pickle_unflatten",
3459
]
3560

3661

@@ -306,3 +331,233 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
306331
out["boolean indexing"] = True
307332
out["data-dependent shapes"] = True
308333
return out
334+
335+
336+
_BASIC_PICKLED_TYPES = frozenset((
337+
bool, int, float, complex, str, bytes, bytearray,
338+
list, tuple, dict, set, frozenset, range, slice,
339+
types.NoneType, types.EllipsisType,
340+
)) # fmt: skip
341+
_BASIC_REST_TYPES = frozenset((
342+
type, types.BuiltinFunctionType, types.FunctionType, types.ModuleType
343+
)) # fmt: skip
344+
345+
FlattenRest: TypeAlias = tuple[object, ...]
346+
347+
348+
def pickle_flatten(
349+
obj: object, cls: type[T] | tuple[type[T], ...]
350+
) -> tuple[list[T], FlattenRest]:
351+
"""
352+
Use the pickle machinery to extract objects out of an arbitrary container.
353+
354+
Unlike regular ``pickle.dumps``, this function always succeeds.
355+
356+
Parameters
357+
----------
358+
obj : object
359+
The object to pickle.
360+
cls : type | tuple[type, ...]
361+
One or multiple classes to extract from the object.
362+
The instances of these classes inside ``obj`` will not be pickled.
363+
364+
Returns
365+
-------
366+
instances : list[cls]
367+
All instances of ``cls`` found inside ``obj`` (not pickled).
368+
rest
369+
Opaque object containing the pickled bytes plus all other objects where
370+
``__reduce__`` / ``__reduce_ex__`` is either not implemented or raised.
371+
These are unpickleable objects, types, modules, and functions.
372+
373+
This object is *typically* hashable save for fairly exotic objects
374+
that are neither pickleable nor hashable.
375+
376+
This object is pickleable if everything except ``instances`` was pickleable
377+
in the input object.
378+
379+
See Also
380+
--------
381+
pickle_unflatten : Reverse function.
382+
383+
Examples
384+
--------
385+
>>> class A:
386+
... def __repr__(self):
387+
... return "<A>"
388+
>>> class NS:
389+
... def __repr__(self):
390+
... return "<NS>"
391+
... def __reduce__(self):
392+
... assert False, "not serializable"
393+
>>> obj = {1: A(), 2: [A(), NS(), A()]}
394+
>>> instances, rest = pickle_flatten(obj, A)
395+
>>> instances
396+
[<A>, <A>, <A>]
397+
>>> pickle_unflatten(instances, rest)
398+
{1: <A>, 2: [<A>, <NS>, <A>]}
399+
400+
This can be also used to swap inner objects; the only constraint is that
401+
the number of objects in and out must be the same:
402+
403+
>>> pickle_unflatten(["foo", "bar", "baz"], rest)
404+
{1: "foo", 2: ["bar", <NS>, "baz"]}
405+
"""
406+
instances: list[T] = []
407+
rest: list[object] = []
408+
409+
class Pickler(pickle.Pickler): # numpydoc ignore=GL08
410+
"""
411+
Use the `pickle.Pickler.persistent_id` hook to extract objects.
412+
"""
413+
414+
@override
415+
def persistent_id(self, obj: object) -> Literal[0, 1, None]: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
416+
if isinstance(obj, cls):
417+
instances.append(obj) # type: ignore[arg-type]
418+
return 0
419+
420+
typ_ = type(obj)
421+
if typ_ in _BASIC_PICKLED_TYPES: # No subclasses!
422+
# If obj is a collection, recursively descend inside it
423+
return None
424+
if typ_ in _BASIC_REST_TYPES:
425+
rest.append(obj)
426+
return 1
427+
428+
try:
429+
# Note: a class that defines __slots__ without defining __getstate__
430+
# cannot be pickled with __reduce__(), but can with __reduce_ex__(5)
431+
_ = obj.__reduce_ex__(5)
432+
except Exception: # pylint: disable=broad-exception-caught
433+
rest.append(obj)
434+
return 1
435+
436+
# Object can be pickled. Let the Pickler recursively descend inside it.
437+
return None
438+
439+
f = io.BytesIO()
440+
p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL)
441+
p.dump(obj)
442+
return instances, (f.getvalue(), *rest)
443+
444+
445+
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # type: ignore[explicit-any]
446+
"""
447+
Reverse of ``pickle_flatten``.
448+
449+
Parameters
450+
----------
451+
instances : Iterable
452+
Inner objects to be reinserted into the flattened container.
453+
rest : FlattenRest
454+
Extra bits, as returned by ``pickle_flatten``.
455+
456+
Returns
457+
-------
458+
object
459+
The outer object originally passed to ``pickle_flatten`` after a
460+
pickle->unpickle round-trip.
461+
462+
See Also
463+
--------
464+
pickle_flatten : Serializing function.
465+
pickle.loads : Standard unpickle function.
466+
467+
Notes
468+
-----
469+
The `instances` iterable must yield at least the same number of elements as the ones
470+
returned by ``pickle_without``, but the elements do not need to be the same objects
471+
or even the same types of objects. Excess elements, if any, will be left untouched.
472+
"""
473+
iters = iter(instances), iter(rest)
474+
pik = cast(bytes, next(iters[1]))
475+
476+
class Unpickler(pickle.Unpickler): # numpydoc ignore=GL08
477+
"""Mirror of the overridden Pickler in pickle_flatten."""
478+
479+
@override
480+
def persistent_load(self, pid: Literal[0, 1]) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
481+
try:
482+
return next(iters[pid])
483+
except StopIteration as e:
484+
msg = "Not enough objects to unpickle"
485+
raise ValueError(msg) from e
486+
487+
f = io.BytesIO(pik)
488+
return Unpickler(f).load()
489+
490+
491+
class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
492+
"""
493+
Helper of :func:`jax_autojit`.
494+
495+
Wrap arbitrary inputs and outputs of the jitted function and
496+
convert them to/from PyTrees.
497+
"""
498+
499+
obj: T
500+
_registered: ClassVar[bool] = False
501+
__slots__: tuple[str, ...] = ("obj",)
502+
503+
def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
504+
self._register()
505+
self.obj = obj
506+
507+
@classmethod
508+
def _register(cls): # numpydoc ignore=SS06
509+
"""
510+
Register upon first use instead of at import time, to avoid
511+
globally importing JAX.
512+
"""
513+
if not cls._registered:
514+
import jax
515+
516+
jax.tree_util.register_pytree_node(
517+
cls,
518+
lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType]
519+
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
520+
)
521+
cls._registered = True
522+
523+
524+
def jax_autojit(
525+
func: Callable[P, T],
526+
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01,SS03
527+
"""
528+
Wrap `func` with ``jax.jit``, with the following differences:
529+
530+
- Python scalar arguments and return values are not automatically converted to
531+
``jax.Array`` objects.
532+
- All non-array arguments are automatically treated as static.
533+
Unlike ``jax.jit``, static arguments must be either hashable or serializable with
534+
``pickle``.
535+
- Unlike ``jax.jit``, non-array arguments and return values are not limited to
536+
tuple/list/dict, but can be any object serializable with ``pickle``.
537+
- Automatically descend into non-array arguments and find ``jax.Array`` objects
538+
inside them, then rebuild the arguments when entering `func`, swapping the JAX
539+
concrete arrays with tracer objects.
540+
- Automatically descend into non-array return values and find ``jax.Array`` objects
541+
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
542+
tracer objects with concrete arrays.
543+
544+
See Also
545+
--------
546+
jax.jit : JAX JIT compilation function.
547+
"""
548+
import jax
549+
550+
@jax.jit # type: ignore[misc] # pyright: ignore[reportUntypedFunctionDecorator]
551+
def inner( # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08
552+
wargs: _AutoJITWrapper[Any],
553+
) -> _AutoJITWrapper[T]:
554+
args, kwargs = wargs.obj
555+
res = func(*args, **kwargs) # pyright: ignore[reportCallIssue]
556+
return _AutoJITWrapper(res)
557+
558+
@wraps(func)
559+
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
560+
wargs = _AutoJITWrapper((args, kwargs))
561+
return inner(wargs).obj
562+
563+
return outer

‎src/array_api_extra/testing.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
from __future__ import annotations
88

99
import contextlib
10-
from collections.abc import Callable, Iterable, Iterator, Sequence
10+
import enum
11+
import warnings
12+
from collections.abc import Callable, Iterator, Sequence
1113
from functools import wraps
1214
from types import ModuleType
1315
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
1416

1517
from ._lib._utils._compat import is_dask_namespace, is_jax_namespace
18+
from ._lib._utils._helpers import jax_autojit, pickle_flatten, pickle_unflatten
1619

1720
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
1821

@@ -26,7 +29,7 @@
2629
# Sphinx hacks
2730
SchedulerGetCallable = object
2831

29-
def override(func: object) -> object:
32+
def override(func):
3033
return func
3134

3235

@@ -36,13 +39,22 @@ def override(func: object) -> object:
3639
_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[explicit-any]
3740

3841

42+
class Deprecated(enum.Enum):
43+
"""Unique type for deprecated parameters."""
44+
45+
DEPRECATED = 1
46+
47+
48+
DEPRECATED = Deprecated.DEPRECATED
49+
50+
3951
def lazy_xp_function( # type: ignore[explicit-any]
4052
func: Callable[..., Any],
4153
*,
4254
allow_dask_compute: bool | int = False,
4355
jax_jit: bool = True,
44-
static_argnums: int | Sequence[int] | None = None,
45-
static_argnames: str | Iterable[str] | None = None,
56+
static_argnums: Deprecated = DEPRECATED,
57+
static_argnames: Deprecated = DEPRECATED,
4658
) -> None: # numpydoc ignore=GL07
4759
"""
4860
Tag a function to be tested on lazy backends.
@@ -82,16 +94,30 @@ def lazy_xp_function( # type: ignore[explicit-any]
8294
Default: False, meaning that `func` must be fully lazy and never materialize the
8395
graph.
8496
jax_jit : bool, optional
85-
Set to True to replace `func` with ``jax.jit(func)`` after calling the
86-
:func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False
87-
if `func` is only compatible with eager (non-jitted) JAX. Default: True.
88-
static_argnums : int | Sequence[int], optional
89-
Passed to jax.jit. Positional arguments to treat as static (compile-time
90-
constant). Default: infer from `static_argnames` using
91-
`inspect.signature(func)`.
92-
static_argnames : str | Iterable[str], optional
93-
Passed to jax.jit. Named arguments to treat as static (compile-time constant).
94-
Default: infer from `static_argnums` using `inspect.signature(func)`.
97+
Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after
98+
calling the :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``.
99+
Set to False if `func` is only compatible with eager (non-jitted) JAX.
100+
101+
Unlike with vanilla ``jax.jit``, all arguments and return types that are not JAX
102+
arrays are treated as static; the function can accept and return arbitrary
103+
wrappers around JAX arrays. This difference is because, in real life, most users
104+
won't wrap the function directly with ``jax.jit`` but rather they will use it
105+
within their own code, which is itself then wrapped by ``jax.jit``, and
106+
internally consume the function's outputs.
107+
108+
In other words, the pattern that is being tested is::
109+
110+
>>> @jax.jit
111+
... def user_func(x):
112+
... y = user_prepares_inputs(x)
113+
... z = func(y, some_static_arg=True)
114+
... return user_consumes(z)
115+
116+
Default: True.
117+
static_argnums :
118+
Deprecated; ignored
119+
static_argnames :
120+
Deprecated; ignored
95121
96122
See Also
97123
--------
@@ -108,7 +134,7 @@ def lazy_xp_function( # type: ignore[explicit-any]
108134
109135
def test_myfunc(xp):
110136
a = xp.asarray([1, 2])
111-
# When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)`
137+
# When xp=jax.numpy, this is similar to `b = jax.jit(myfunc)(a)`
112138
# When xp=dask.array, crash on compute() or persist()
113139
b = myfunc(a)
114140
@@ -168,12 +194,20 @@ def test_myfunc(xp):
168194
b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
169195
c = naked.myfunc(a) # This is not
170196
"""
197+
if static_argnums is not DEPRECATED or static_argnames is not DEPRECATED:
198+
warnings.warn(
199+
(
200+
"The `static_argnums` and `static_argnames` parameters are deprecated "
201+
"and ignored. They will be removed in a future version."
202+
),
203+
DeprecationWarning,
204+
stacklevel=2,
205+
)
171206
tags = {
172207
"allow_dask_compute": allow_dask_compute,
173208
"jax_jit": jax_jit,
174-
"static_argnums": static_argnums,
175-
"static_argnames": static_argnames,
176209
}
210+
177211
try:
178212
func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
179213
except AttributeError: # @cython.vectorize
@@ -247,19 +281,9 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
247281
monkeypatch.setattr(mod, name, wrapped)
248282

249283
elif is_jax_namespace(xp):
250-
import jax
251-
252284
for mod, name, func, tags in iter_tagged():
253285
if tags["jax_jit"]:
254-
# suppress unused-ignore to run mypy in -e lint as well as -e dev
255-
wrapped = cast( # type: ignore[explicit-any]
256-
Callable[..., Any],
257-
jax.jit(
258-
func,
259-
static_argnums=tags["static_argnums"],
260-
static_argnames=tags["static_argnames"],
261-
),
262-
)
286+
wrapped = jax_autojit(func)
263287
monkeypatch.setattr(mod, name, wrapped)
264288

265289

@@ -308,6 +332,7 @@ def _dask_wrap(
308332
After the function returns, materialize the graph in order to re-raise exceptions.
309333
"""
310334
import dask
335+
import dask.array as da
311336

312337
func_name = getattr(func, "__name__", str(func))
313338
n_str = f"only up to {n}" if n else "no"
@@ -327,6 +352,8 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
327352
# Block until the graph materializes and reraise exceptions. This allows
328353
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
329354
# not work on scheduler='distributed', as it would not block.
330-
return dask.persist(out, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
355+
arrays, rest = pickle_flatten(out, da.Array)
356+
arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
357+
return pickle_unflatten(arrays, rest) # pyright: ignore[reportUnknownArgumentType]
331358

332359
return wrapper

‎tests/conftest.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -148,19 +148,7 @@ def xp(
148148
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
149149

150150
if library.like(Backend.JAX):
151-
import jax
152-
153-
# suppress unused-ignore to run mypy in -e lint as well as -e dev
154-
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
155-
156-
if library == Backend.JAX_GPU:
157-
try:
158-
device = jax.devices("cuda")[0]
159-
except RuntimeError:
160-
pytest.skip("no CUDA device available")
161-
else:
162-
device = jax.devices("cpu")[0]
163-
jax.config.update("jax_default_device", device)
151+
_setup_jax(library)
164152

165153
elif library == Backend.TORCH_GPU:
166154
import torch.cuda
@@ -175,6 +163,22 @@ def xp(
175163
yield xp
176164

177165

166+
def _setup_jax(library: Backend) -> None:
167+
import jax
168+
169+
# suppress unused-ignore to run mypy in -e lint as well as -e dev
170+
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
171+
172+
if library == Backend.JAX_GPU:
173+
try:
174+
device = jax.devices("cuda")[0]
175+
except RuntimeError:
176+
pytest.skip("no CUDA device available")
177+
else:
178+
device = jax.devices("cpu")[0]
179+
jax.config.update("jax_default_device", device)
180+
181+
178182
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
179183
def da(
180184
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
@@ -186,6 +190,17 @@ def da(
186190
return xp
187191

188192

193+
@pytest.fixture(params=[Backend.JAX, Backend.JAX_GPU])
194+
def jnp(
195+
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
196+
) -> ModuleType: # numpydoc ignore=PR01,RT01
197+
"""Variant of the `xp` fixture that only yields jax.numpy."""
198+
xp = pytest.importorskip("jax.numpy")
199+
_setup_jax(request.param)
200+
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
201+
return xp
202+
203+
189204
@pytest.fixture
190205
def device(
191206
library: Backend, xp: ModuleType

‎tests/test_at.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import math
2-
import pickle
32
from collections.abc import Callable, Generator
43
from contextlib import contextmanager
54
from types import ModuleType
@@ -41,28 +40,11 @@ def at_op(
4140
just a workaround for when one wants to apply jax.jit to `at()` directly,
4241
which is not a common use case.
4342
"""
44-
if isinstance(idx, (slice | tuple)):
45-
return _at_op(x, None, pickle.dumps(idx), op, y, copy=copy, xp=xp)
46-
return _at_op(x, idx, None, op, y, copy=copy, xp=xp)
47-
48-
49-
def _at_op(
50-
x: Array,
51-
idx: SetIndex | None,
52-
idx_pickle: bytes | None,
53-
op: _AtOp,
54-
y: Array | object,
55-
copy: bool | None,
56-
xp: ModuleType | None = None,
57-
) -> Array:
58-
"""jitted helper of at_op"""
59-
if idx_pickle:
60-
idx = pickle.loads(idx_pickle)
61-
meth = cast(Callable[..., Array], getattr(at(x, cast(SetIndex, idx)), op.value)) # type: ignore[explicit-any]
43+
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[explicit-any]
6244
return meth(y, copy=copy, xp=xp)
6345

6446

65-
lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp"))
47+
lazy_xp_function(at_op)
6648

6749

6850
@contextmanager

‎tests/test_funcs.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@
3737
# some xp backends are untyped
3838
# mypy: disable-error-code=no-untyped-def
3939

40-
lazy_xp_function(apply_where, static_argnums=(2, 3), static_argnames="xp")
41-
lazy_xp_function(atleast_nd, static_argnames=("ndim", "xp"))
42-
lazy_xp_function(cov, static_argnames="xp")
43-
lazy_xp_function(create_diagonal, static_argnames=("offset", "xp"))
44-
lazy_xp_function(expand_dims, static_argnames=("axis", "xp"))
45-
lazy_xp_function(kron, static_argnames="xp")
46-
lazy_xp_function(nunique, static_argnames="xp")
47-
lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp"))
40+
lazy_xp_function(apply_where)
41+
lazy_xp_function(atleast_nd)
42+
lazy_xp_function(cov)
43+
lazy_xp_function(create_diagonal)
44+
lazy_xp_function(expand_dims)
45+
lazy_xp_function(kron)
46+
lazy_xp_function(nunique)
47+
lazy_xp_function(pad)
4848
# FIXME calls in1d which calls xp.unique_values without size
49-
lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp"))
50-
lazy_xp_function(sinc, static_argnames="xp")
49+
lazy_xp_function(setdiff1d, jax_jit=False)
50+
lazy_xp_function(sinc)
5151

5252

5353
class TestApplyWhere:

‎tests/test_helpers.py

Lines changed: 197 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from types import ModuleType
2-
from typing import cast
2+
from typing import TYPE_CHECKING, Generic, TypeVar, cast
33

44
import numpy as np
55
import pytest
@@ -13,18 +13,31 @@
1313
capabilities,
1414
eager_shape,
1515
in1d,
16+
jax_autojit,
1617
meta_namespace,
1718
ndindex,
19+
pickle_flatten,
20+
pickle_unflatten,
1821
)
1922
from array_api_extra._lib._utils._typing import Array, Device, DType
2023
from array_api_extra.testing import lazy_xp_function
2124

2225
from .conftest import np_compat
2326

27+
if TYPE_CHECKING:
28+
# TODO import from typing (requires Python >=3.12)
29+
from typing_extensions import override
30+
else:
31+
32+
def override(func):
33+
return func
34+
2435
# mypy: disable-error-code=no-untyped-usage
2536

37+
T = TypeVar("T")
38+
2639
# FIXME calls xp.unique_values without size
27-
lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))
40+
lazy_xp_function(in1d, jax_jit=False)
2841

2942

3043
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse")
@@ -204,3 +217,185 @@ def test_capabilities(xp: ModuleType):
204217
if xp.__array_api_version__ >= "2024.12":
205218
expect.add("max dimensions")
206219
assert capabilities(xp).keys() == expect
220+
221+
222+
class Wrapper(Generic[T]):
223+
"""Trivial opaque wrapper. Must be pickleable."""
224+
225+
x: T
226+
# __slots__ make this object serializable with __reduce_ex__(5),
227+
# but not with __reduce__
228+
__slots__: tuple[str, ...] = ("x",)
229+
230+
def __init__(self, x: T):
231+
self.x = x
232+
233+
# Note: this makes the object not hashable
234+
@override
235+
def __eq__(self, other: object) -> bool:
236+
return isinstance(other, Wrapper) and self.x == other.x
237+
238+
239+
class TestPickleFlatten:
240+
def test_roundtrip(self):
241+
class NotSerializable:
242+
@override
243+
def __reduce__(self) -> tuple[object, ...]:
244+
raise NotImplementedError()
245+
246+
# Note: NotHashable() instances can be reduced to an
247+
# unserializable local class
248+
class NotHashable:
249+
@override
250+
def __eq__(self, other: object) -> bool:
251+
return isinstance(other, type(self)) and other.__dict__ == self.__dict__
252+
253+
with pytest.raises(TypeError):
254+
_ = hash(NotHashable())
255+
256+
# Extracted objects need be neither pickleable nor serializable
257+
class C(NotSerializable, NotHashable):
258+
x: int
259+
260+
def __init__(self, x: int):
261+
self.x = x
262+
263+
class D(C):
264+
pass
265+
266+
c1 = C(1)
267+
c2 = C(2)
268+
d3 = D(3)
269+
270+
# An assorted bunch of opaque containers, standard containers,
271+
# non-serializable objects, and non-hashable objects (but not at the same time)
272+
obj = Wrapper([1, c1, {2: (c2, {NotSerializable()})}, NotHashable(), d3])
273+
instances, rest = pickle_flatten(obj, C)
274+
275+
assert instances == [c1, c2, d3]
276+
obj2 = pickle_unflatten(instances, rest)
277+
assert obj2 == obj
278+
279+
def test_swap_objects(self):
280+
class C:
281+
pass
282+
283+
obj = [1, C(), {2: (C(), {C()})}]
284+
_, rest = pickle_flatten(obj, C)
285+
obj2 = pickle_unflatten(["foo", "bar", "baz"], rest)
286+
assert obj2 == [1, "foo", {2: ("bar", {"baz"})}]
287+
288+
def test_multi_class(self):
289+
class C:
290+
pass
291+
292+
class D:
293+
pass
294+
295+
c, d = C(), D()
296+
instances, _ = pickle_flatten([c, d], (C, D))
297+
assert len(instances) == 2
298+
assert instances[0] is c
299+
assert instances[1] is d
300+
301+
def test_no_class(self):
302+
obj = {1: "foo", 2: (3, 4)}
303+
instances, rest = pickle_flatten(obj, ()) # type: ignore[var-annotated]
304+
assert instances == []
305+
obj2 = pickle_unflatten([], rest)
306+
assert obj2 == obj
307+
308+
def test_flattened_stream(self):
309+
"""
310+
Test that multiple calls to flatten() can feed into the same stream of instances
311+
"""
312+
obj1 = Wrapper(1)
313+
obj2 = [Wrapper(2), Wrapper(3)]
314+
instances1, rest1 = pickle_flatten(obj1, Wrapper)
315+
instances2, rest2 = pickle_flatten(obj2, Wrapper)
316+
it = iter(instances1 + instances2 + [Wrapper(4)]) # pyright: ignore[reportUnknownArgumentType]
317+
assert pickle_unflatten(it, rest1) == obj1 # pyright: ignore[reportUnknownArgumentType]
318+
assert pickle_unflatten(it, rest2) == obj2 # pyright: ignore[reportUnknownArgumentType]
319+
assert list(it) == [Wrapper(4)] # pyright: ignore[reportUnknownArgumentType]
320+
321+
def test_too_short(self):
322+
obj = [Wrapper(1), Wrapper(2)]
323+
instances, rest = pickle_flatten(obj, Wrapper)
324+
with pytest.raises(ValueError, match="Not enough"):
325+
pickle_unflatten(instances[:1], rest) # pyright: ignore[reportUnknownArgumentType]
326+
327+
def test_recursion(self):
328+
obj: list[object] = [Wrapper(1)]
329+
obj.append(obj)
330+
331+
instances, rest = pickle_flatten(obj, Wrapper)
332+
assert instances == [Wrapper(1)]
333+
334+
obj2 = pickle_unflatten(instances, rest) # pyright: ignore[reportUnknownArgumentType]
335+
assert len(obj2) == 2
336+
assert obj2[0] is obj[0]
337+
assert obj2[1] is obj2
338+
339+
340+
class TestJAXAutoJIT:
341+
def test_basic(self, jnp: ModuleType):
342+
@jax_autojit
343+
def f(x: Array, k: object = False) -> Array:
344+
return x + 1 if k else x - 1
345+
346+
# Basic recognition of static_argnames
347+
xp_assert_equal(f(jnp.asarray([1, 2])), jnp.asarray([0, 1]))
348+
xp_assert_equal(f(jnp.asarray([1, 2]), False), jnp.asarray([0, 1]))
349+
xp_assert_equal(f(jnp.asarray([1, 2]), True), jnp.asarray([2, 3]))
350+
xp_assert_equal(f(jnp.asarray([1, 2]), 1), jnp.asarray([2, 3]))
351+
352+
# static argument is not an ArrayLike
353+
xp_assert_equal(f(jnp.asarray([1, 2]), "foo"), jnp.asarray([2, 3]))
354+
355+
# static argument is not hashable, but serializable
356+
xp_assert_equal(f(jnp.asarray([1, 2]), ["foo"]), jnp.asarray([2, 3]))
357+
358+
def test_wrapper(self, jnp: ModuleType):
359+
@jax_autojit
360+
def f(w: Wrapper[Array]) -> Wrapper[Array]:
361+
return Wrapper(w.x + 1)
362+
363+
inp = Wrapper(jnp.asarray([1, 2]))
364+
out = f(inp).x
365+
xp_assert_equal(out, jnp.asarray([2, 3]))
366+
367+
def test_static_hashable(self, jnp: ModuleType):
368+
"""Static argument/return value is hashable, but not serializable"""
369+
370+
class C:
371+
def __reduce__(self) -> object: # type: ignore[explicit-override,override] # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride]
372+
raise Exception()
373+
374+
@jax_autojit
375+
def f(x: object) -> object:
376+
return x
377+
378+
inp = C()
379+
out = f(inp)
380+
assert out is inp
381+
382+
# Serializable opaque input contains non-serializable object plus array
383+
inp = Wrapper((C(), jnp.asarray([1, 2])))
384+
out = f(inp)
385+
assert isinstance(out, Wrapper)
386+
assert out.x[0] is inp.x[0]
387+
assert out.x[1] is not inp.x[1]
388+
xp_assert_equal(out.x[1], inp.x[1]) # pyright: ignore[reportUnknownArgumentType]
389+
390+
def test_arraylikes_are_static(self):
391+
pytest.importorskip("jax")
392+
393+
@jax_autojit
394+
def f(x: list[int]) -> list[int]:
395+
assert isinstance(x, list)
396+
assert x == [1, 2]
397+
return [3, 4]
398+
399+
out = f([1, 2])
400+
assert isinstance(out, list)
401+
assert out == [3, 4]

‎tests/test_lazy.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
from array_api_extra._lib._utils._typing import Array, Device
1616
from array_api_extra.testing import lazy_xp_function
1717

18-
lazy_xp_function(
19-
lazy_apply, static_argnames=("func", "shape", "dtype", "as_numpy", "xp")
20-
)
18+
lazy_xp_function(lazy_apply)
2119

2220
as_numpy = pytest.mark.parametrize(
2321
"as_numpy",
@@ -386,7 +384,7 @@ def eager(
386384
)
387385

388386

389-
lazy_xp_function(check_lazy_apply_kwargs, static_argnames=("expect_cls", "as_numpy"))
387+
lazy_xp_function(check_lazy_apply_kwargs)
390388

391389

392390
@as_numpy

‎tests/test_testing.py

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -185,32 +185,24 @@ def static_params(x: Array, n: int, flag: bool = False) -> Array:
185185
return x * 3.0
186186

187187

188-
def static_params1(x: Array, n: int, flag: bool = False) -> Array:
189-
return static_params(x, n, flag)
188+
lazy_xp_function(static_params)
190189

191190

192-
def static_params2(x: Array, n: int, flag: bool = False) -> Array:
193-
return static_params(x, n, flag)
194-
195-
196-
def static_params3(x: Array, n: int, flag: bool = False) -> Array:
197-
return static_params(x, n, flag)
198-
199-
200-
lazy_xp_function(static_params1, static_argnums=(1, 2))
201-
lazy_xp_function(static_params2, static_argnames=("n", "flag"))
202-
lazy_xp_function(static_params3, static_argnums=1, static_argnames="flag")
191+
def test_lazy_xp_function_static_params(xp: ModuleType):
192+
x = xp.asarray([1.0, 2.0])
193+
xp_assert_equal(static_params(x, 1), xp.asarray([3.0, 6.0]))
194+
xp_assert_equal(static_params(x, 1, True), xp.asarray([2.0, 4.0]))
195+
xp_assert_equal(static_params(x, 1, False), xp.asarray([3.0, 6.0]))
196+
xp_assert_equal(static_params(x, 0, False), xp.asarray([3.0, 6.0]))
197+
xp_assert_equal(static_params(x, 1, flag=True), xp.asarray([2.0, 4.0]))
198+
xp_assert_equal(static_params(x, n=1, flag=True), xp.asarray([2.0, 4.0]))
203199

204200

205-
@pytest.mark.parametrize("func", [static_params1, static_params2, static_params3])
206-
def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Array]): # type: ignore[explicit-any]
207-
x = xp.asarray([1.0, 2.0])
208-
xp_assert_equal(func(x, 1), xp.asarray([3.0, 6.0]))
209-
xp_assert_equal(func(x, 1, True), xp.asarray([2.0, 4.0]))
210-
xp_assert_equal(func(x, 1, False), xp.asarray([3.0, 6.0]))
211-
xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0]))
212-
xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0]))
213-
xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0]))
201+
def test_lazy_xp_function_deprecated_static_argnames():
202+
with pytest.warns(DeprecationWarning, match="static_argnames"):
203+
lazy_xp_function(static_params, static_argnames=["flag"]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
204+
with pytest.warns(DeprecationWarning, match="static_argnums"):
205+
lazy_xp_function(static_params, static_argnums=[1]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
214206

215207

216208
try:
@@ -273,6 +265,66 @@ def test_lazy_xp_function_eagerly_raises(da: ModuleType):
273265
_ = dask_raises(x)
274266

275267

268+
class Wrapper:
269+
"""Trivial opaque wrapper. Must be pickleable."""
270+
271+
x: Array
272+
273+
def __init__(self, x: Array):
274+
self.x = x
275+
276+
277+
def check_opaque_wrapper(w: Wrapper, xp: ModuleType) -> Wrapper:
278+
assert isinstance(w, Wrapper)
279+
assert array_namespace(w.x) == xp
280+
return Wrapper(w.x + 1)
281+
282+
283+
lazy_xp_function(check_opaque_wrapper)
284+
285+
286+
def test_lazy_xp_function_opaque_wrappers(xp: ModuleType):
287+
"""
288+
Test that function input and output can be wrapped into arbitrary
289+
serializable Python objects, even if jax.jit does not support them.
290+
"""
291+
x = xp.asarray([1, 2])
292+
xp2 = array_namespace(x) # Revert NUMPY_READONLY to array_api_compat.numpy
293+
res = check_opaque_wrapper(Wrapper(x), xp2)
294+
xp_assert_equal(res.x, xp.asarray([2, 3]))
295+
296+
297+
def test_lazy_xp_function_opaque_wrappers_eagerly_raise(da: ModuleType):
298+
"""
299+
Like `test_lazy_xp_function_eagerly_raises`, but the returned object is
300+
wrapped in an opaque wrapper.
301+
"""
302+
x = da.arange(3)
303+
with pytest.raises(ValueError, match="Hello world"):
304+
_ = Wrapper(dask_raises(x))
305+
306+
307+
def check_recursive(x: list[object]) -> list[object]:
308+
assert isinstance(x, list)
309+
assert x[1] is x
310+
y: list[object] = [cast(Array, x[0]) + 1]
311+
y.append(y)
312+
return y
313+
314+
315+
lazy_xp_function(check_recursive)
316+
317+
318+
def test_lazy_xp_function_recursive(xp: ModuleType):
319+
"""Test that inputs and outputs can be recursive data structures."""
320+
x: list[object] = [xp.asarray([1, 2])]
321+
x.append(x)
322+
y = check_recursive(x)
323+
assert isinstance(y, list)
324+
xp_assert_equal(cast(Array, y[0]), xp.asarray([2, 3]))
325+
assert y[1] is y
326+
327+
276328
wrapped = ModuleType("wrapped")
277329
naked = ModuleType("naked")
278330

0 commit comments

Comments
 (0)
Please sign in to comment.