Skip to content

Commit 4047aa2

Browse files
lukebaumann authored and RoshaniN committed
Small update to keep elastic training working
Adds input_data_shardings to elastic_handler
1 parent 5551f52 commit 4047aa2

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

MaxText/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -745,4 +745,4 @@ projector_output_dim_for_vit: 4096
745745
rope_theta_for_vit: 10000
746746
vision_output_dim_for_vit: 4096
747747
pixel_shuffle_ratio_for_vit: 0.5
748-
projector_dropout_for_vit: 0.0
748+
projector_dropout_for_vit: 0.0

MaxText/elastic_train.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,7 @@ def elastic_handler(
122122
)
123123
with mesh:
124124
data_iterator, _ = create_data_iterator(config, mesh)
125+
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
125126

126127
step, snapshot_jax_arrays, _ = elastic_manager.get_resharded_snapshot(mesh)
127128

@@ -183,6 +184,7 @@ def elastic_handler(
183184
example_batch,
184185
learning_rate_schedule,
185186
metric_logger,
187+
input_data_shardings,
186188
)
187189

188190

@@ -301,6 +303,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
301303
example_batch,
302304
learning_rate_schedule,
303305
metric_logger,
306+
input_data_shardings,
304307
) = ret
305308

306309
if step == start_step:
@@ -335,6 +338,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
335338
example_batch,
336339
learning_rate_schedule,
337340
metric_logger,
341+
input_data_shardings,
338342
) = ret
339343
except exceptions.StopTraining as error:
340344
max_logging.log(f"Training stopped: {str(error)}")

0 commit comments

Comments
 (0)