Skip to content

Commit 4d8953d

Browse files
authored
DOC: autojit notes (#297)
1 parent 28a364d commit 4d8953d

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: #
463463
Notes
464464
-----
465465
The `instances` iterable must yield at least the same number of elements as the ones
466-
returned by ``pickle_without``, but the elements do not need to be the same objects
466+
returned by ``pickle_flatten``, but the elements do not need to be the same objects
467467
or even the same types of objects. Excess elements, if any, will be left untouched.
468468
"""
469469
iters = iter(instances), iter(rest)
@@ -540,6 +540,25 @@ def jax_autojit(
540540
See Also
541541
--------
542542
jax.jit : JAX JIT compilation function.
543+
544+
Notes
545+
-----
546+
These are useful choices *for testing purposes only*, which is how this function is
547+
intended to be used. The output of ``jax.jit`` is a C++ level callable, that
548+
directly dispatches to the compiled kernel after the initial call. In comparison,
549+
``jax_autojit`` incurs a much higher dispatch time.
550+
551+
Additionally, consider::
552+
553+
def f(x: Array, y: float, plus: bool) -> Array:
554+
return x + y if plus else x - y
555+
556+
j1 = jax.jit(f, static_argnames="plus")
557+
j2 = jax_autojit(f)
558+
559+
In the above example, ``j2`` requires a lot less setup to be tested effectively than
560+
``j1``, but on the flip side it means that it will be re-traced for every different
561+
value of ``y``, which likely makes it not fit for purpose in production.
543562
"""
544563
import jax
545564

src/array_api_extra/testing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def lazy_xp_function( # type: ignore[explicit-any]
9696
jax_jit : bool, optional
9797
Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after
9898
calling the :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``.
99+
This is the default behaviour.
99100
Set to False if `func` is only compatible with eager (non-jitted) JAX.
100101
101102
Unlike with vanilla ``jax.jit``, all arguments and return types that are not JAX

0 commit comments

Comments
 (0)