Skip to content

Commit a12698d

Browse files
authored
Add jax compilation cache config (#198)
1 parent fe22a9f commit a12698d

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

jetstream_pt/cli.py

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def create_engine(devices):
5555
"""Create Pytorch engine from flags"""
5656
torch.set_default_dtype(torch.bfloat16)
5757
quant_config = config.create_quantization_config_from_flags()
58+
config.set_jax_compilation_cache_config()
5859
env_data = fetch_models.construct_env_data_from_model_id(
5960
FLAGS.model_id,
6061
FLAGS.override_batch_size,

jetstream_pt/config.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# limitations under the License.
1414

1515

16+
import os
1617
from absl import flags
18+
import jax
1719
from jetstream_pt.environment import QuantizationConfig
1820

1921
FLAGS = flags.FLAGS
@@ -154,17 +156,17 @@
154156
"page size per page",
155157
)
156158
flags.DEFINE_string(
157-
"jax_compilation_cache_dir",
159+
"internal_jax_compilation_cache_dir",
158160
"~/jax_cache",
159161
"Jax compilation cache directory",
160162
)
161163
flags.DEFINE_integer(
162-
"jax_persistent_cache_min_entry_size_bytes",
164+
"internal_jax_persistent_cache_min_entry_size_bytes",
163165
0,
164166
"Minimum size (in bytes) of an entry that will be cached in the persistent compilation cache",
165167
)
166168
flags.DEFINE_integer(
167-
"jax_persistent_cache_min_compile_time_secs",
169+
"internal_jax_persistent_cache_min_compile_time_secs",
168170
1,
169171
"Minimum compilation time for a computation to be written to persistent cache",
170172
)
@@ -190,3 +192,19 @@ def create_quantization_config_from_flags():
190192
else FLAGS.quantize_weights
191193
)
192194
return config
195+
196+
197+
def set_jax_compilation_cache_config():
198+
"""Sets the jax compilation cache configuration"""
199+
jax.config.update(
200+
"jax_compilation_cache_dir",
201+
os.path.expanduser(FLAGS.internal_jax_compilation_cache_dir),
202+
)
203+
jax.config.update(
204+
"jax_persistent_cache_min_entry_size_bytes",
205+
FLAGS.internal_jax_persistent_cache_min_entry_size_bytes,
206+
)
207+
jax.config.update(
208+
"jax_persistent_cache_min_compile_time_secs",
209+
FLAGS.internal_jax_persistent_cache_min_compile_time_secs,
210+
)

0 commit comments

Comments
 (0)