13
13
# limitations under the License.
14
14
15
15
16
+ import os
16
17
from absl import flags
18
+ import jax
17
19
from jetstream_pt .environment import QuantizationConfig
18
20
19
21
FLAGS = flags .FLAGS
154
156
"page size per page" ,
155
157
)
156
158
flags .DEFINE_string (
157
- "jax_compilation_cache_dir " ,
159
+ "internal_jax_compilation_cache_dir " ,
158
160
"~/jax_cache" ,
159
161
"Jax compilation cache directory" ,
160
162
)
161
163
flags .DEFINE_integer (
162
- "jax_persistent_cache_min_entry_size_bytes " ,
164
+ "internal_jax_persistent_cache_min_entry_size_bytes " ,
163
165
0 ,
164
166
"Minimum size (in bytes) of an entry that will be cached in the persistent compilation cache" ,
165
167
)
166
168
flags .DEFINE_integer (
167
- "jax_persistent_cache_min_compile_time_secs " ,
169
+ "internal_jax_persistent_cache_min_compile_time_secs " ,
168
170
1 ,
169
171
"Minimum compilation time for a computation to be written to persistent cache" ,
170
172
)
@@ -190,3 +192,19 @@ def create_quantization_config_from_flags():
190
192
else FLAGS .quantize_weights
191
193
)
192
194
return config
195
+
196
+
197
+ def set_jax_compilation_cache_config ():
198
+ """Sets the jax compilation cache configuration"""
199
+ jax .config .update (
200
+ "jax_compilation_cache_dir" ,
201
+ os .path .expanduser (FLAGS .internal_jax_compilation_cache_dir ),
202
+ )
203
+ jax .config .update (
204
+ "jax_persistent_cache_min_entry_size_bytes" ,
205
+ FLAGS .internal_jax_persistent_cache_min_entry_size_bytes ,
206
+ )
207
+ jax .config .update (
208
+ "jax_persistent_cache_min_compile_time_secs" ,
209
+ FLAGS .internal_jax_persistent_cache_min_compile_time_secs ,
210
+ )
0 commit comments