File tree 2 files changed +4
-1
lines changed
2 files changed +4
-1
lines changed Original file line number Diff line number Diff line change 1
1
from flax import jax_utils
2
+ from flax .core import FrozenDict
2
3
import jax
3
4
import numpy as np
4
5
import torch
@@ -16,6 +17,8 @@ def torch2jax(jax_workload,
16
17
jax_params , model_state = jax_workload .init_model_fn (jax .random .PRNGKey (0 ),
17
18
** init_kwargs )
18
19
pytorch_model , _ = pytorch_workload .init_model_fn ([0 ], ** init_kwargs )
20
+ if isinstance (jax_params , dict ):
21
+ jax_params = FrozenDict (jax_params )
19
22
jax_params = jax_utils .unreplicate (jax_params ).unfreeze ()
20
23
if model_state is not None :
21
24
model_state = jax_utils .unreplicate (model_state )
Original file line number Diff line number Diff line change @@ -50,7 +50,7 @@ def test_workload(self, workload):
50
50
pyt_logs = '/tmp/pyt_log.pkl'
51
51
try :
52
52
run (
53
- f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python3 -m tests.reference_algorithm_tests --workload={ workload } --framework=jax --global_batch_size={ GLOBAL_BATCH_SIZE } --log_file={ jax_logs } '
53
+ f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={ workload } --framework=jax --global_batch_size={ GLOBAL_BATCH_SIZE } --log_file={ jax_logs } '
54
54
f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={ NUM_TRAIN_STEPS } ' ,
55
55
shell = True ,
56
56
stdout = DEVNULL ,
You can’t perform that action at this time.
0 commit comments