Skip to content

Commit 28a364d

Browse files
authored
Merge pull request #284 from crusaderky/autojit
ENH: `jax_autojit`
1 parent 6fc85ba commit 28a364d

File tree

8 files changed

+628
-104
lines changed

8 files changed

+628
-104
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 259 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,24 @@
22

33
from __future__ import annotations
44

5+
import io
56
import math
6-
from collections.abc import Generator, Iterable
7+
import pickle
8+
import types
9+
from collections.abc import Callable, Generator, Iterable
10+
from functools import wraps
711
from types import ModuleType
8-
from typing import TYPE_CHECKING, cast
12+
from typing import (
13+
TYPE_CHECKING,
14+
Any,
15+
ClassVar,
16+
Generic,
17+
Literal,
18+
ParamSpec,
19+
TypeAlias,
20+
TypeVar,
21+
cast,
22+
)
923

1024
from . import _compat
1125
from ._compat import (
@@ -19,8 +33,16 @@
1933
from ._typing import Array
2034

2135
if TYPE_CHECKING: # pragma: no cover
22-
# TODO import from typing (requires Python >=3.13)
23-
from typing_extensions import TypeIs
36+
# TODO import from typing (requires Python >=3.12 and >=3.13)
37+
from typing_extensions import TypeIs, override
38+
else:
39+
40+
def override(func):
41+
return func
42+
43+
44+
P = ParamSpec("P")
45+
T = TypeVar("T")
2446

2547

2648
__all__ = [
@@ -29,8 +51,11 @@
2951
"eager_shape",
3052
"in1d",
3153
"is_python_scalar",
54+
"jax_autojit",
3255
"mean",
3356
"meta_namespace",
57+
"pickle_flatten",
58+
"pickle_unflatten",
3459
]
3560

3661

@@ -302,3 +327,233 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
302327
out = out.copy()
303328
out["boolean indexing"] = False
304329
return out
330+
331+
332+
_BASIC_PICKLED_TYPES = frozenset((
333+
bool, int, float, complex, str, bytes, bytearray,
334+
list, tuple, dict, set, frozenset, range, slice,
335+
types.NoneType, types.EllipsisType,
336+
)) # fmt: skip
337+
_BASIC_REST_TYPES = frozenset((
338+
type, types.BuiltinFunctionType, types.FunctionType, types.ModuleType
339+
)) # fmt: skip
340+
341+
FlattenRest: TypeAlias = tuple[object, ...]
342+
343+
344+
def pickle_flatten(
345+
obj: object, cls: type[T] | tuple[type[T], ...]
346+
) -> tuple[list[T], FlattenRest]:
347+
"""
348+
Use the pickle machinery to extract objects out of an arbitrary container.
349+
350+
Unlike regular ``pickle.dumps``, this function always succeeds.
351+
352+
Parameters
353+
----------
354+
obj : object
355+
The object to pickle.
356+
cls : type | tuple[type, ...]
357+
One or multiple classes to extract from the object.
358+
The instances of these classes inside ``obj`` will not be pickled.
359+
360+
Returns
361+
-------
362+
instances : list[cls]
363+
All instances of ``cls`` found inside ``obj`` (not pickled).
364+
rest
365+
Opaque object containing the pickled bytes plus all other objects where
366+
``__reduce__`` / ``__reduce_ex__`` is either not implemented or raised.
367+
These are unpickleable objects, types, modules, and functions.
368+
369+
This object is *typically* hashable save for fairly exotic objects
370+
that are neither pickleable nor hashable.
371+
372+
This object is pickleable if everything except ``instances`` was pickleable
373+
in the input object.
374+
375+
See Also
376+
--------
377+
pickle_unflatten : Reverse function.
378+
379+
Examples
380+
--------
381+
>>> class A:
382+
... def __repr__(self):
383+
... return "<A>"
384+
>>> class NS:
385+
... def __repr__(self):
386+
... return "<NS>"
387+
... def __reduce__(self):
388+
... assert False, "not serializable"
389+
>>> obj = {1: A(), 2: [A(), NS(), A()]}
390+
>>> instances, rest = pickle_flatten(obj, A)
391+
>>> instances
392+
[<A>, <A>, <A>]
393+
>>> pickle_unflatten(instances, rest)
394+
{1: <A>, 2: [<A>, <NS>, <A>]}
395+
396+
This can be also used to swap inner objects; the only constraint is that
397+
the number of objects in and out must be the same:
398+
399+
>>> pickle_unflatten(["foo", "bar", "baz"], rest)
400+
{1: "foo", 2: ["bar", <NS>, "baz"]}
401+
"""
402+
instances: list[T] = []
403+
rest: list[object] = []
404+
405+
class Pickler(pickle.Pickler): # numpydoc ignore=GL08
406+
"""
407+
Use the `pickle.Pickler.persistent_id` hook to extract objects.
408+
"""
409+
410+
@override
411+
def persistent_id(self, obj: object) -> Literal[0, 1, None]: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
412+
if isinstance(obj, cls):
413+
instances.append(obj) # type: ignore[arg-type]
414+
return 0
415+
416+
typ_ = type(obj)
417+
if typ_ in _BASIC_PICKLED_TYPES: # No subclasses!
418+
# If obj is a collection, recursively descend inside it
419+
return None
420+
if typ_ in _BASIC_REST_TYPES:
421+
rest.append(obj)
422+
return 1
423+
424+
try:
425+
# Note: a class that defines __slots__ without defining __getstate__
426+
# cannot be pickled with __reduce__(), but can with __reduce_ex__(5)
427+
_ = obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL)
428+
except Exception: # pylint: disable=broad-exception-caught
429+
rest.append(obj)
430+
return 1
431+
432+
# Object can be pickled. Let the Pickler recursively descend inside it.
433+
return None
434+
435+
f = io.BytesIO()
436+
p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL)
437+
p.dump(obj)
438+
return instances, (f.getvalue(), *rest)
439+
440+
441+
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # type: ignore[explicit-any]
442+
"""
443+
Reverse of ``pickle_flatten``.
444+
445+
Parameters
446+
----------
447+
instances : Iterable
448+
Inner objects to be reinserted into the flattened container.
449+
rest : FlattenRest
450+
Extra bits, as returned by ``pickle_flatten``.
451+
452+
Returns
453+
-------
454+
object
455+
The outer object originally passed to ``pickle_flatten`` after a
456+
pickle->unpickle round-trip.
457+
458+
See Also
459+
--------
460+
pickle_flatten : Serializing function.
461+
pickle.loads : Standard unpickle function.
462+
463+
Notes
464+
-----
465+
The `instances` iterable must yield at least the same number of elements as the ones
466+
returned by ``pickle_without``, but the elements do not need to be the same objects
467+
or even the same types of objects. Excess elements, if any, will be left untouched.
468+
"""
469+
iters = iter(instances), iter(rest)
470+
pik = cast(bytes, next(iters[1]))
471+
472+
class Unpickler(pickle.Unpickler): # numpydoc ignore=GL08
473+
"""Mirror of the overridden Pickler in pickle_flatten."""
474+
475+
@override
476+
def persistent_load(self, pid: Literal[0, 1]) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
477+
try:
478+
return next(iters[pid])
479+
except StopIteration as e:
480+
msg = "Not enough objects to unpickle"
481+
raise ValueError(msg) from e
482+
483+
f = io.BytesIO(pik)
484+
return Unpickler(f).load()
485+
486+
487+
class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
488+
"""
489+
Helper of :func:`jax_autojit`.
490+
491+
Wrap arbitrary inputs and outputs of the jitted function and
492+
convert them to/from PyTrees.
493+
"""
494+
495+
obj: T
496+
_registered: ClassVar[bool] = False
497+
__slots__: tuple[str, ...] = ("obj",)
498+
499+
def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
500+
self._register()
501+
self.obj = obj
502+
503+
@classmethod
504+
def _register(cls): # numpydoc ignore=SS06
505+
"""
506+
Register upon first use instead of at import time, to avoid
507+
globally importing JAX.
508+
"""
509+
if not cls._registered:
510+
import jax
511+
512+
jax.tree_util.register_pytree_node(
513+
cls,
514+
lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType]
515+
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
516+
)
517+
cls._registered = True
518+
519+
520+
def jax_autojit(
521+
func: Callable[P, T],
522+
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01,SS03
523+
"""
524+
Wrap `func` with ``jax.jit``, with the following differences:
525+
526+
- Python scalar arguments and return values are not automatically converted to
527+
``jax.Array`` objects.
528+
- All non-array arguments are automatically treated as static.
529+
Unlike ``jax.jit``, static arguments must be either hashable or serializable with
530+
``pickle``.
531+
- Unlike ``jax.jit``, non-array arguments and return values are not limited to
532+
tuple/list/dict, but can be any object serializable with ``pickle``.
533+
- Automatically descend into non-array arguments and find ``jax.Array`` objects
534+
inside them, then rebuild the arguments when entering `func`, swapping the JAX
535+
concrete arrays with tracer objects.
536+
- Automatically descend into non-array return values and find ``jax.Array`` objects
537+
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
538+
tracer objects with concrete arrays.
539+
540+
See Also
541+
--------
542+
jax.jit : JAX JIT compilation function.
543+
"""
544+
import jax
545+
546+
@jax.jit # type: ignore[misc] # pyright: ignore[reportUntypedFunctionDecorator]
547+
def inner( # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08
548+
wargs: _AutoJITWrapper[Any],
549+
) -> _AutoJITWrapper[T]:
550+
args, kwargs = wargs.obj
551+
res = func(*args, **kwargs) # pyright: ignore[reportCallIssue]
552+
return _AutoJITWrapper(res)
553+
554+
@wraps(func)
555+
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
556+
wargs = _AutoJITWrapper((args, kwargs))
557+
return inner(wargs).obj
558+
559+
return outer

0 commit comments

Comments
 (0)