Commit 5551f52

maxtext authors committed:
Merge pull request #1862 from AI-Hypercomputer:data_loader

PiperOrigin-RevId: 776212972
2 parents: 7cee948 + 34229f5

File tree: 8 files changed, +414 −249 lines

MaxText/data_loader.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pytype: disable=unsupported-operands
+"""Module to load data for training."""
+
+import jax
+import jax.numpy as jnp
+from jax.experimental import checkify
+
+from MaxText import exceptions
+from MaxText import maxtext_utils
+from MaxText.utils.goodput_utils import (
+    GoodputEvent,
+    maybe_record_goodput,
+)
+
+
+class DataLoader:
+  """
+  Loads preprocessed data for training.
+  """
+
+  def __init__(self, config, mesh, data_iterator, goodput_recorder):
+    self.config = config
+    self.goodput_recorder = goodput_recorder
+    self.data_iterator = data_iterator
+    self.last_batch = None
+    self.input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
+
+  def load_next_batch(self):
+    """Loads the next batch. Can keep reusing the same batch for performance reasons."""
+    with maybe_record_goodput(self.goodput_recorder, GoodputEvent.DATA_LOADING):
+      try:
+        if self.config.reuse_example_batch and self.last_batch:
+          example_batch = self.last_batch
+        else:
+          example_batch = next(self.data_iterator)
+        # Reshard data from loaded sharding to performant activation sharding
+        self.last_batch = jax.lax.with_sharding_constraint(example_batch, self.input_data_shardings)
+        self.check_example_batch()
+      except Exception as e:  # pylint: disable=broad-except
+        if "StopIteration" in str(e):
+          raise exceptions.StopTraining("You may have run out of training data.")
+        else:
+          raise exceptions.StopTraining("`load_next_batch()` failed.")
+    return self.last_batch
+
+  def check_example_batch(self):
+    if self.config.max_checkify:
+      jittable_f = checkify.checkify(lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!"))
+      # Check if inputs in batch contain bad synthetic data.
+      # pylint: disable=not-callable
+      err, _ = jax.jit(jittable_f)(self.last_batch["inputs"][: self.config.global_batch_size_to_train_on, :])
+      err.throw()
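
For orientation, here is a minimal sketch (not part of the diff) of how a train loop is expected to consume this class, based on the call sites updated in this commit; run_train_loop and train_step are hypothetical stand-ins:

from MaxText import exceptions, max_logging
from MaxText.data_loader import DataLoader


def run_train_loop(config, mesh, data_iterator, recorder, state, train_step):
  data_loader = DataLoader(config, mesh, data_iterator, recorder)
  try:
    for step in range(config.steps):
      # load_next_batch() wraps the goodput DATA_LOADING recording, the reshard
      # to the performant input sharding, and the optional checkify batch check.
      example_batch = data_loader.load_next_batch()
      state = train_step(state, example_batch, step)
  except exceptions.StopTraining as e:
    # Raised when the iterator is exhausted or any loading step fails.
    max_logging.log(f"Training stopped: {str(e)}")
  return state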

MaxText/elastic_train.py

Lines changed: 6 additions & 13 deletions
@@ -62,16 +62,16 @@
 import tensorflow as tf

 from MaxText import checkpointing
+from MaxText import exceptions
 from MaxText import max_utils
 from MaxText import maxtext_utils
 from MaxText import max_logging
 from MaxText import profiler
 from MaxText import pyconfig
+from MaxText.data_loader import DataLoader
 from MaxText.input_pipeline.input_pipeline_interface import create_data_iterator
 from MaxText.metric_logger import MetricLogger
-from MaxText.train import check_example_batch
 from MaxText.train import get_first_step
-from MaxText.train import load_next_batch
 from MaxText.train import save_checkpoint
 from MaxText.train import setup_mesh_and_model
 from MaxText.train import setup_train_loop

@@ -233,8 +233,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
       block=True,
   )

-  input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
-
+  data_loader = DataLoader(config, mesh, data_iterator, recorder)
   metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

   # Write train config params, num model params, and XLA flags to tensorboard

@@ -250,15 +249,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
     max_logging.log(f"{step=} {elastic_manager.elastic_down_event_count=} {elastic_manager.good_slice_count=}")
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules), jax.default_device(elastic_manager.default_device):
       with jax.profiler.StepTraceAnnotation("train", step_num=step):
-        with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
-          try:
-            example_batch = load_next_batch(data_iterator, example_batch, config)
-            example_batch = jax.lax.with_sharding_constraint(example_batch, input_data_shardings)
-          except Exception as e:  # pylint: disable=broad-except
-            max_logging.log(f"load_next_batch failed, you may have run out of data. Error message: {e}")
-            break
-
-        check_example_batch(config, example_batch=example_batch)
+        example_batch = data_loader.load_next_batch()
         # pylint: disable=not-callable
         nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
         with maybe_record_goodput(recorder, GoodputEvent.STEP, step):

@@ -345,6 +336,8 @@ def train_loop(config, elastic_manager, recorder, state=None):
             learning_rate_schedule,
             metric_logger,
           ) = ret
+  except exceptions.StopTraining as error:
+    max_logging.log(f"Training stopped: {str(error)}")

   if checkpoint_manager is not None:
     if ((int(state.step) - 1) % config.checkpoint_period != 0) and (int(state.step) != 0):
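
The key control-flow change in this file: running out of data used to log and `break` from inside the step, while now the stop condition travels as exceptions.StopTraining to a single handler outside the loop, and the checkpoint-save block below still runs afterward. A self-contained toy version of that pattern (plain Python, no MaxText imports; all names are placeholders):

class StopTraining(Exception):  # stand-in for MaxText's exceptions.StopTraining
  pass


def load_next_batch(it):
  """Toy loader: converts iterator exhaustion into StopTraining."""
  try:
    return next(it)
  except StopIteration as exc:
    raise StopTraining("You may have run out of training data.") from exc


batches = iter([1, 2, 3])
try:
  while True:
    print(load_next_batch(batches))
except StopTraining as e:
  print(f"Training stopped: {e}")
# Post-loop work (e.g. the final checkpoint save above) still runs here.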

MaxText/exceptions.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Custom exceptions for MaxText."""
+
+
+class StopTraining(Exception):
+  """Custom exception to halt a training process."""
+
+  def __init__(self, reason):
+    super().__init__(reason)
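
As a quick illustration (not from the diff), the class's whole contract is "an Exception that carries a human-readable reason":

from MaxText import exceptions

try:
  raise exceptions.StopTraining("`load_next_batch()` failed.")
except exceptions.StopTraining as e:
  print(str(e))  # -> `load_next_batch()` failed.

The train loops in this commit catch it explicitly and then fall through to their final checkpoint-save logic rather than exiting abruptly.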

MaxText/experimental/rl/grpo_trainer.py

Lines changed: 67 additions & 75 deletions
@@ -50,24 +50,23 @@
 import transformers

 from MaxText import checkpointing
+from MaxText import exceptions
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText import maxengine
 from MaxText import maxtext_utils
 from MaxText import profiler
 from MaxText import pyconfig
 from MaxText.common_types import Array
+from MaxText.data_loader import DataLoader
 from MaxText.experimental.rl import grpo_input_pipeline
-from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
 from MaxText.globals import EPS
 from MaxText.layers import models
 from MaxText.metric_logger import MetricLogger
 from MaxText.train import (
     validate_train_config,
     get_first_step,
-    load_next_batch,
     save_checkpoint,
-    check_example_batch,
     setup_mesh_and_model,
 )
 from MaxText.utils.goodput_utils import (

@@ -765,84 +764,77 @@ def train_loop(config, config_inference, recorder, state=None):

   example_batch = None

-  input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
-
+  data_loader = DataLoader(config, mesh, data_iterator, recorder)
   metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

   # Write train config params, num model params, and XLA flags to tensorboard
   metric_logger.write_setup_info_to_tensorboard(state.params)

-  for step in np.arange(start_step, config.steps):
-    step_start_time = datetime.datetime.now()
-    prof.maybe_activate_profiler(step, state)
-
-    with jax.profiler.StepTraceAnnotation("train", step_num=step):
-      with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
-        try:
-          example_batch = load_next_batch(data_iterator, example_batch, config)
-          example_batch = jax.lax.with_sharding_constraint(example_batch, input_data_shardings)
-        except Exception as e:  # pylint: disable=broad-except
-          max_logging.log(f"load_next_batch failed, you may have run out of data. Error message: {e}")
-          break
-
-      check_example_batch(config, example_batch=example_batch)
-      # pylint: disable=not-callable
-      rng = jax.jit(jax.random.fold_in)(init_rng, step)
-      with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
-        rng, rng_gen = random.split(rng)
-        example_batch = p_generate_completions(example_batch, state.params, rng_gen)
-
-      # TODO: ensure this partitioning is correct
-      with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-        state, metrics = p_train_step(state, example_batch, rng)
-
-    if checkpoint_manager is not None:
-      state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
-      if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
-        checkpointing.print_save_message(step, config.async_checkpointing)
-
-      # Upon preemption, exit when and only when all ongoing saves are complete.
-      if checkpoint_manager.reached_preemption(step):
-        checkpoint_manager.wait_until_finished()
-        sys.exit()
-
-    if config.dump_hlo and step == start_step:
-      jax.block_until_ready(state)  # Ensure compilation has finished.
-      max_utils.upload_dump(
-          config.dump_hlo_local_dir,
-          config.dump_hlo_gcs_dir,
-          module_name=config.dump_hlo_module_name,
-          delete_local_after=config.dump_hlo_delete_local_after,
-          all_host_upload=config.dump_hlo_upload_all,
-      )
+  try:
+    for step in np.arange(start_step, config.steps):
+      step_start_time = datetime.datetime.now()
+      prof.maybe_activate_profiler(step, state)
+
+      with jax.profiler.StepTraceAnnotation("train", step_num=step):
+        example_batch = data_loader.load_next_batch()
+        # pylint: disable=not-callable
+        rng = jax.jit(jax.random.fold_in)(init_rng, step)
+        with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
+          rng, rng_gen = random.split(rng)
+          example_batch = p_generate_completions(example_batch, state.params, rng_gen)
+
+        # TODO: ensure this partitioning is correct
+        with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+          state, metrics = p_train_step(state, example_batch, rng)
+
+      if checkpoint_manager is not None:
+        state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
+        if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
+          checkpointing.print_save_message(step, config.async_checkpointing)
+
+        # Upon preemption, exit when and only when all ongoing saves are complete.
+        if checkpoint_manager.reached_preemption(step):
+          checkpoint_manager.wait_until_finished()
+          sys.exit()
+
+      if config.dump_hlo and step == start_step:
+        jax.block_until_ready(state)  # Ensure compilation has finished.
+        max_utils.upload_dump(
+            config.dump_hlo_local_dir,
+            config.dump_hlo_gcs_dir,
+            module_name=config.dump_hlo_module_name,
+            delete_local_after=config.dump_hlo_delete_local_after,
+            all_host_upload=config.dump_hlo_upload_all,
+        )

-    if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
-      assert eval_data_iterator
-      eval_step_count = 0
-      # pylint: disable=not-callable
-      for eval_batch in eval_data_iterator:
-        if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
-          break
-        with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-          eval_metrics = p_eval_step(state, eval_batch, rng)
-        metric_logger.record_eval_metrics(step, metrics=eval_metrics)
-        max_logging.log(f"Completed eval step {eval_step_count}")
-        eval_step_count += 1
-      metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
-      if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
-        max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}")
-        prof.deactivate()
-        break
-
-    prof.maybe_deactivate_profiler(step, state)
-
-    if step == start_step:
-      max_utils.print_mem_stats("After params initialized")
-
-    jax.block_until_ready(state)  # ensure training step is completed
-
-    step_time_delta = datetime.datetime.now() - step_start_time
-    metric_logger.record_train_metrics(metrics, step, step_time_delta)
+      if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
+        assert eval_data_iterator
+        eval_step_count = 0
+        # pylint: disable=not-callable
+        for eval_batch in eval_data_iterator:
+          if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
+            break
+          with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+            eval_metrics = p_eval_step(state, eval_batch, rng)
+          metric_logger.record_eval_metrics(step, metrics=eval_metrics)
+          max_logging.log(f"Completed eval step {eval_step_count}")
+          eval_step_count += 1
+        metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
+        if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
+          prof.deactivate()
+          raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.")
+
+      prof.maybe_deactivate_profiler(step, state)
+
+      if step == start_step:
+        max_utils.print_mem_stats("After params initialized")
+
+      jax.block_until_ready(state)  # ensure training step is completed
+
+      step_time_delta = datetime.datetime.now() - step_start_time
+      metric_logger.record_train_metrics(metrics, step, step_time_delta)
+  except exceptions.StopTraining as e:
+    max_logging.log(f"Training stopped: {str(e)}")

   if checkpoint_manager is not None:
     if ((int(state.step) - 1) % config.checkpoint_period != 0) and (int(state.step) != 0):
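
The eval-driven early stop above follows the same pattern: instead of logging and `break`ing, reaching config.target_eval_loss now raises StopTraining, which the try/except around the whole step loop turns into a logged, clean exit. A condensed sketch of just that check; maybe_stop_on_target_loss is a hypothetical helper, not a MaxText function:

from MaxText import exceptions


def maybe_stop_on_target_loss(avg_eval_loss, target_eval_loss):
  # Mirrors the check in train_loop: hitting the target stops training
  # via an exception rather than a break statement.
  if avg_eval_loss <= target_eval_loss:
    raise exceptions.StopTraining(f"Target loss {target_eval_loss=} is achieved.")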
