Skip to content

Commit c4f517e

Browse files
committed
Small update to keep elastic training working
1 parent f52c26c commit c4f517e

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

MaxText/elastic_train.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def elastic_handler(
125125
)
126126
with mesh:
127127
data_iterator, _ = create_data_iterator(config, mesh)
128+
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
128129

129130
step, snapshot_jax_arrays, _ = elastic_manager.get_resharded_snapshot(mesh)
130131

@@ -187,6 +188,7 @@ def elastic_handler(
187188
learning_rate_schedule,
188189
metric_logger,
189190
writer,
191+
input_data_shardings,
190192
)
191193

192194

@@ -348,6 +350,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
348350
learning_rate_schedule,
349351
metric_logger,
350352
writer,
353+
input_data_shardings,
351354
) = ret
352355

353356
if step == start_step:
@@ -378,6 +381,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
378381
learning_rate_schedule,
379382
metric_logger,
380383
writer,
384+
input_data_shardings,
381385
) = ret
382386

383387
if checkpoint_manager is not None:

0 commit comments

Comments
 (0)