Skip to content

Commit 3fb722d

Browse files
committed
revert changes to submission runner for prng key
1 parent d57dec3 commit 3fb722d

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

submission_runner.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,7 @@ def train_once(
214214
_reset_cuda_mem()
215215
data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
216216

217-
if FLAGS.framework == 'jax':
218-
data_rng = jax.random.key_data(data_rng)
217+
data_rng = jax.random.key_data(data_rng)
219218
# Workload setup.
220219
logging.info('Initializing dataset.')
221220
if hasattr(workload, '_eval_num_workers'):
@@ -348,8 +347,7 @@ def train_once(
348347
data_select_rng, update_rng, prep_eval_rng, eval_rng = \
349348
prng.split(step_rng, 4)
350349

351-
if FLAGS.framework == 'jax':
352-
eval_rng = jax.random.key_data(eval_rng)
350+
eval_rng = jax.random.key_data(eval_rng)
353351

354352
with profiler.profile('Data selection'):
355353
batch = data_selection(workload,

0 commit comments

Comments
 (0)