Skip to content

Commit f7205e1

Browse files
committed
Accelerate testing. Add jax cache helper.
1 parent 0f3dc63 commit f7205e1

File tree

7 files changed

+131
-5
lines changed

7 files changed

+131
-5
lines changed

crazyflow/utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from functools import partial
4+
from pathlib import Path
45
from typing import TypeVar
56

67
import jax
@@ -99,3 +100,17 @@ def _add_marker_to_scene(self: BaseRender, marker: dict):
99100
self.scn.ngeom += 1
100101

101102
BaseRender._add_marker_to_scene = _add_marker_to_scene
103+
104+
105+
def enable_cache(
106+
cache_path: Path = Path("/tmp/jax_cache"),
107+
min_entry_size_bytes: int = -1,
108+
min_compile_time_secs: int = 0,
109+
enable_xla_caches: bool = False,
110+
):
111+
"""Enable JAX cache."""
112+
jax.config.update("jax_compilation_cache_dir", str(cache_path))
113+
jax.config.update("jax_persistent_cache_min_entry_size_bytes", min_entry_size_bytes)
114+
jax.config.update("jax_persistent_cache_min_compile_time_secs", min_compile_time_secs)
115+
if enable_xla_caches:
116+
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

examples/cache.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""This example demonstrates persistent JAX function caching with the Sim class.
2+
3+
When JAX functions decorated with @jit are first called, JAX traces and compiles them to XLA. This
4+
compilation is expensive but only needs to happen once. Subsequent calls reuse the compiled version.
5+
However, the cache is not persistent between Python sessions.
6+
7+
The Sim class uses many jitted functions internally, particularly in the step() method which
8+
compiles a chain of physics and control functions. On the first step() call, the entire chain is
9+
compiled.
10+
11+
After the Python session ends, the cached functions get lost. However, we can enable a persistent
12+
cache that is used when the function we are jit-compiling has been compiled before, which
13+
significantly speeds up the jit compile time at the first step.
14+
15+
By enabling caching:
16+
1. The first run compiles and caches all jitted functions persistently
17+
2. The second run loads the cached compiled functions instead of recompiling
18+
3. This gives a significant speedup in initialization and first step times
19+
20+
The cache persists between Python sessions, so compilation only needs to happen once on a machine,
21+
or when the cache directory is deleted.
22+
"""
23+
24+
import time
25+
from pathlib import Path
26+
27+
from crazyflow.sim import Sim
28+
from crazyflow.utils import enable_cache
29+
30+
31+
def main():
32+
cache_dir = Path("/tmp/jax_cache_test")
33+
if use_cache := cache_dir.exists():
34+
print("Cache directory exists. This run will be fast.")
35+
print("\nTo run without cache, delete the directory.")
36+
else:
37+
print("Cache directory does not exist. This run will be slow.")
38+
print("\nTo run with cache, run this script again.")
39+
enable_cache(cache_path=cache_dir)
40+
t0 = time.perf_counter()
41+
sim = Sim()
42+
t1 = time.perf_counter()
43+
sim.step()
44+
t2 = time.perf_counter()
45+
prefix = "Using cache: " if use_cache else "Not using cache: "
46+
print(f"{prefix}\n Init: {t1 - t0:.3f}s\n Step: {t2 - t1:.3f}s")
47+
48+
49+
if __name__ == "__main__":
50+
main()

examples/spiral.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def control(start_xy: np.ndarray, t: float) -> np.ndarray:
1313

1414

1515
def main():
16-
sim = Sim(n_drones=4, control=Control.state, integrator="rk4", physics="mujoco")
16+
sim = Sim(n_drones=4, control=Control.state, integrator="rk4", physics="analytical")
1717
sim.reset()
1818
duration = 5.0
1919
fps = 60

tests/conftest.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import jax
2+
3+
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
4+
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
5+
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
6+
# Do not enable XLA caches, crashes PyTest
7+
# jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

tests/integration/test_examples.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@
66
import pytest
77

88
EXAMPLES_DIR = Path(__file__).resolve().parent.parent.parent / "examples"
9-
example_scripts = list(sorted(EXAMPLES_DIR.glob("*.py")))
9+
# Make example_scripts a list of strings instead of Path objects so that pytest can use it in its
10+
# automatic printouts. We convert the elements back to Paths in the test function.
11+
example_scripts = [str(p) for p in sorted(EXAMPLES_DIR.glob("*.py"))]
1012

1113

12-
@pytest.mark.parametrize("example_script", example_scripts)
14+
@pytest.mark.parametrize("example_script", [str(p) for p in example_scripts])
1315
@pytest.mark.timeout(60)
1416
@pytest.mark.integration
15-
def test_example_main(example_script: Path):
17+
def test_example_main(example_script: str):
1618
"""Dynamically import and execute the main function from an example script."""
1719
# Add the examples directory to sys.path to resolve imports
20+
example_script = Path(example_script)
1821
sys.path.insert(0, str(EXAMPLES_DIR))
1922

2023
# Dynamically import the module

tests/integration/test_interfaces.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import jax
12
import numpy as np
23
import pytest
34
from scipy.spatial.transform import Rotation as R
@@ -31,14 +32,15 @@ def test_state_interface(physics: Physics):
3132
def test_attitude_interface(physics: Physics):
3233
sim = Sim(physics=physics, control=Control.attitude)
3334
target_pos = np.array([0.0, 0.0, 1.0])
35+
jit_state2attitude = jax.jit(state2attitude)
3436

3537
i_error = np.zeros((1, 1, 3))
3638

3739
for _ in range(int(2 * sim.control_freq)): # Run simulation for 2 seconds
3840
pos, vel, quat = sim.data.states.pos, sim.data.states.vel, sim.data.states.quat
3941
des_pos = np.array([[[0, 0, 1.0]]])
4042
dt = 1 / sim.data.controls.attitude_freq
41-
cmd, i_error = state2attitude(
43+
cmd, i_error = jit_state2attitude(
4244
pos, vel, quat, des_pos, np.zeros((1, 1, 3)), np.zeros((1, 1, 1)), i_error, dt
4345
)
4446
sim.attitude_control(cmd)

tests/unit/test_utils.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import jax
2+
import pytest
3+
4+
from crazyflow.utils import enable_cache
5+
6+
7+
@pytest.mark.unit
8+
@pytest.mark.parametrize("enable_xla", [True, False])
9+
def test_enable_cache(enable_xla: bool):
10+
"""Test that enable_cache correctly sets JAX cache configuration."""
11+
# Store original config values
12+
orig_cache_dir = jax.config.values.get("jax_compilation_cache_dir", None)
13+
orig_min_size = jax.config.values.get("jax_persistent_cache_min_entry_size_bytes", None)
14+
orig_min_time = jax.config.values.get("jax_persistent_cache_min_compile_time_secs", None)
15+
orig_xla = jax.config.values.get("jax_persistent_cache_enable_xla_caches", None)
16+
17+
try:
18+
cache_path = "/tmp/jax_cache"
19+
min_size = 1000
20+
min_time = 2
21+
22+
enable_cache(
23+
cache_path=cache_path,
24+
min_entry_size_bytes=min_size,
25+
min_compile_time_secs=min_time,
26+
enable_xla_caches=enable_xla,
27+
)
28+
29+
assert cache_path == jax.config.jax_compilation_cache_dir, "Cache path not set correctly"
30+
assert (
31+
min_size == jax.config.jax_persistent_cache_min_entry_size_bytes
32+
), "Min size not set correctly"
33+
assert (
34+
min_time == jax.config.jax_persistent_cache_min_compile_time_secs
35+
), "Min time not set correctly"
36+
expected_xla = "all" if enable_xla else orig_xla
37+
assert (
38+
expected_xla == jax.config.jax_persistent_cache_enable_xla_caches
39+
), "XLA caches not set correctly"
40+
41+
finally:
42+
if orig_cache_dir is not None:
43+
jax.config.update("jax_compilation_cache_dir", orig_cache_dir)
44+
if orig_min_size is not None:
45+
jax.config.update("jax_persistent_cache_min_entry_size_bytes", orig_min_size)
46+
if orig_min_time is not None:
47+
jax.config.update("jax_persistent_cache_min_compile_time_secs", orig_min_time)
48+
if orig_xla is not None:
49+
jax.config.update("jax_persistent_cache_enable_xla_caches", orig_xla)

0 commit comments

Comments
 (0)