Skip to content

Commit d57dec3

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/python_upgrades' into python_upgrades
2 parents 5715618 + 7d580f1 commit d57dec3

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

submission_runner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ def train_once(
213213
) -> Tuple[spec.Timing, Dict[str, Any]]:
214214
_reset_cuda_mem()
215215
data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
216-
data_rng = jax.random.key_data(data_rng)
216+
217+
if FLAGS.framework == 'jax':
218+
data_rng = jax.random.key_data(data_rng)
217219
# Workload setup.
218220
logging.info('Initializing dataset.')
219221
if hasattr(workload, '_eval_num_workers'):
@@ -345,7 +347,9 @@ def train_once(
345347

346348
data_select_rng, update_rng, prep_eval_rng, eval_rng = \
347349
prng.split(step_rng, 4)
348-
eval_rng = jax.random.key_data(eval_rng)
350+
351+
if FLAGS.framework == 'jax':
352+
eval_rng = jax.random.key_data(eval_rng)
349353

350354
with profiler.profile('Data selection'):
351355
batch = data_selection(workload,

0 commit comments

Comments
 (0)