|
| 1 | +import jax |
| 2 | +import pytest |
| 3 | + |
| 4 | +from crazyflow.utils import enable_cache |
| 5 | + |
| 6 | + |
| 7 | +@pytest.mark.unit |
| 8 | +@pytest.mark.parametrize("enable_xla", [True, False]) |
| 9 | +def test_enable_cache(enable_xla: bool): |
| 10 | + """Test that enable_cache correctly sets JAX cache configuration.""" |
| 11 | + # Store original config values |
| 12 | + orig_cache_dir = jax.config.values.get("jax_compilation_cache_dir", None) |
| 13 | + orig_min_size = jax.config.values.get("jax_persistent_cache_min_entry_size_bytes", None) |
| 14 | + orig_min_time = jax.config.values.get("jax_persistent_cache_min_compile_time_secs", None) |
| 15 | + orig_xla = jax.config.values.get("jax_persistent_cache_enable_xla_caches", None) |
| 16 | + |
| 17 | + try: |
| 18 | + cache_path = "/tmp/jax_cache" |
| 19 | + min_size = 1000 |
| 20 | + min_time = 2 |
| 21 | + |
| 22 | + enable_cache( |
| 23 | + cache_path=cache_path, |
| 24 | + min_entry_size_bytes=min_size, |
| 25 | + min_compile_time_secs=min_time, |
| 26 | + enable_xla_caches=enable_xla, |
| 27 | + ) |
| 28 | + |
| 29 | + assert cache_path == jax.config.jax_compilation_cache_dir, "Cache path not set correctly" |
| 30 | + assert ( |
| 31 | + min_size == jax.config.jax_persistent_cache_min_entry_size_bytes |
| 32 | + ), "Min size not set correctly" |
| 33 | + assert ( |
| 34 | + min_time == jax.config.jax_persistent_cache_min_compile_time_secs |
| 35 | + ), "Min time not set correctly" |
| 36 | + expected_xla = "all" if enable_xla else orig_xla |
| 37 | + assert ( |
| 38 | + expected_xla == jax.config.jax_persistent_cache_enable_xla_caches |
| 39 | + ), "XLA caches not set correctly" |
| 40 | + |
| 41 | + finally: |
| 42 | + if orig_cache_dir is not None: |
| 43 | + jax.config.update("jax_compilation_cache_dir", orig_cache_dir) |
| 44 | + if orig_min_size is not None: |
| 45 | + jax.config.update("jax_persistent_cache_min_entry_size_bytes", orig_min_size) |
| 46 | + if orig_min_time is not None: |
| 47 | + jax.config.update("jax_persistent_cache_min_compile_time_secs", orig_min_time) |
| 48 | + if orig_xla is not None: |
| 49 | + jax.config.update("jax_persistent_cache_enable_xla_caches", orig_xla) |
0 commit comments