@@ -463,7 +463,7 @@ def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: #
463
463
Notes
464
464
-----
465
465
The `instances` iterable must yield at least the same number of elements as the ones
466
- returned by ``pickle_without ``, but the elements do not need to be the same objects
466
+ returned by ``pickle_flatten ``, but the elements do not need to be the same objects
467
467
or even the same types of objects. Excess elements, if any, will be left untouched.
468
468
"""
469
469
iters = iter (instances ), iter (rest )
@@ -540,6 +540,25 @@ def jax_autojit(
540
540
See Also
541
541
--------
542
542
jax.jit : JAX JIT compilation function.
543
+
544
+ Notes
545
+ -----
546
+ These are useful choices *for testing purposes only*, which is how this function is
547
+ intended to be used. The output of ``jax.jit`` is a C++ level callable, that
548
+ directly dispatches to the compiled kernel after the initial call. In comparison,
549
+ ``jax_autojit`` incurs a much higher dispatch time.
550
+
551
+ Additionally, consider::
552
+
553
+ def f(x: Array, y: float, plus: bool) -> Array:
554
+ return x + y if plus else x - y
555
+
556
+ j1 = jax.jit(f, static_argnames="plus")
557
+ j2 = jax_autojit(f)
558
+
559
+ In the above example, ``j2`` requires a lot less setup to be tested effectively than
560
+ ``j1``, but on the flip side it means that it will be re-traced for every different
561
+ value of ``y``, which likely makes it not fit for purpose in production.
543
562
"""
544
563
import jax
545
564
0 commit comments