
Add DataLoader module and introduce custom StopTraining exception #1862


Merged · 1 commit · Jun 26, 2025
65 changes: 65 additions & 0 deletions MaxText/data_loader.py
@@ -0,0 +1,65 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module to load data for training."""

import jax
import jax.numpy as jnp
from jax.experimental import checkify

from MaxText import exceptions
from MaxText import maxtext_utils
from MaxText.utils.goodput_utils import (
    GoodputEvent,
    maybe_record_goodput,
)


class DataLoader:
  """
  Loads preprocessed data for training.
  """

  def __init__(self, config, mesh, data_iterator, goodput_recorder):
    self.config = config
    self.goodput_recorder = goodput_recorder
    self.data_iterator = data_iterator
    self.last_batch = None
    self.input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)

  def load_next_batch(self):
    """Loads the next batch. Can keep reusing the same batch for performance reasons."""
    with maybe_record_goodput(self.goodput_recorder, GoodputEvent.DATA_LOADING):
      try:
        if self.config.reuse_example_batch and self.last_batch:
          example_batch = self.last_batch
        else:
          example_batch = next(self.data_iterator)
        # Reshard data from loaded sharding to performant activation sharding
        self.last_batch = jax.lax.with_sharding_constraint(example_batch, self.input_data_shardings)
        self.check_example_batch()
      except Exception as e:  # pylint: disable=broad-except
        if "StopIteration" in str(e):
          raise exceptions.StopTraining("You may have run out of training data.")
        else:
          raise exceptions.StopTraining("`load_next_batch()` failed.")
    return self.last_batch

  def check_example_batch(self):
    if self.config.max_checkify:
      jittable_f = checkify.checkify(lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!"))
      # Check if inputs in batch contain bad synthetic data.
      # pylint: disable=not-callable
      err, _ = jax.jit(jittable_f)(self.last_batch["inputs"][: self.config.global_batch_size_to_train_on, :])
      err.throw()
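
For reference, a minimal usage sketch (not part of this PR's diff) of how the new DataLoader is meant to be driven from a training loop, mirroring the changes to elastic_train.py and grpo_trainer.py below. It assumes config, mesh, data_iterator, recorder, state, rng, start_step, and a compiled p_train_step already exist, as they do in those trainers.

from MaxText import exceptions, max_logging
from MaxText.data_loader import DataLoader

data_loader = DataLoader(config, mesh, data_iterator, recorder)
try:
  for step in range(start_step, config.steps):
    # load_next_batch applies the activation sharding constraint and raises
    # StopTraining when the iterator is exhausted or a batch fails to load.
    example_batch = data_loader.load_next_batch()
    state, metrics = p_train_step(state, example_batch, rng)
except exceptions.StopTraining as error:
  max_logging.log(f"Training stopped: {str(error)}")
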
19 changes: 6 additions & 13 deletions MaxText/elastic_train.py
@@ -62,16 +62,16 @@
import tensorflow as tf

from MaxText import checkpointing
from MaxText import exceptions
from MaxText import max_utils
from MaxText import maxtext_utils
from MaxText import max_logging
from MaxText import profiler
from MaxText import pyconfig
from MaxText.data_loader import DataLoader
from MaxText.input_pipeline.input_pipeline_interface import create_data_iterator
from MaxText.metric_logger import MetricLogger
from MaxText.train import check_example_batch
from MaxText.train import get_first_step
from MaxText.train import load_next_batch
from MaxText.train import save_checkpoint
from MaxText.train import setup_mesh_and_model
from MaxText.train import setup_train_loop
@@ -233,8 +233,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
block=True,
)

input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)

data_loader = DataLoader(config, mesh, data_iterator, recorder)
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

# Write train config params, num model params, and XLA flags to tensorboard
@@ -250,15 +249,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
max_logging.log(f"{step=} {elastic_manager.elastic_down_event_count=} {elastic_manager.good_slice_count=}")
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules), jax.default_device(elastic_manager.default_device):
with jax.profiler.StepTraceAnnotation("train", step_num=step):
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
try:
example_batch = load_next_batch(data_iterator, example_batch, config)
example_batch = jax.lax.with_sharding_constraint(example_batch, input_data_shardings)
except Exception as e: # pylint: disable=broad-except
max_logging.log(f"load_next_batch failed, you may have run out of data. Error message: {e}")
break

check_example_batch(config, example_batch=example_batch)
example_batch = data_loader.load_next_batch()
# pylint: disable=not-callable
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
@@ -345,6 +336,8 @@ def train_loop(config, elastic_manager, recorder, state=None):
learning_rate_schedule,
metric_logger,
) = ret
except exceptions.StopTraining as error:
max_logging.log(f"Training stopped: {str(error)}")

if checkpoint_manager is not None:
if ((int(state.step) - 1) % config.checkpoint_period != 0) and (int(state.step) != 0):
22 changes: 22 additions & 0 deletions MaxText/exceptions.py
@@ -0,0 +1,22 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom exceptions for MaxText."""


class StopTraining(Exception):
  """Custom exception to halt a training process."""

  def __init__(self, reason):
    super().__init__(reason)
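
An illustrative sketch (not part of this PR's diff) of the intended control flow for StopTraining: code raises it with a human-readable reason, and the outer loop catches and logs it instead of breaking out on a bare exception. The maybe_stop helper is hypothetical and exists only to show the pattern.

from MaxText import exceptions

def maybe_stop(eval_loss, target_eval_loss):
  # Hypothetical helper: request a clean stop once the eval loss reaches the target.
  if eval_loss <= target_eval_loss:
    raise exceptions.StopTraining(f"Target loss {target_eval_loss} is achieved.")

try:
  maybe_stop(eval_loss=0.08, target_eval_loss=0.1)
except exceptions.StopTraining as error:
  print(f"Training stopped: {str(error)}")
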
142 changes: 67 additions & 75 deletions MaxText/experimental/rl/grpo_trainer.py
@@ -50,24 +50,23 @@
import transformers

from MaxText import checkpointing
from MaxText import exceptions
from MaxText import max_logging
from MaxText import max_utils
from MaxText import maxengine
from MaxText import maxtext_utils
from MaxText import profiler
from MaxText import pyconfig
from MaxText.common_types import Array
from MaxText.data_loader import DataLoader
from MaxText.experimental.rl import grpo_input_pipeline
from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
from MaxText.globals import EPS
from MaxText.layers import models
from MaxText.metric_logger import MetricLogger
from MaxText.train import (
validate_train_config,
get_first_step,
load_next_batch,
save_checkpoint,
check_example_batch,
setup_mesh_and_model,
)
from MaxText.utils.goodput_utils import (
@@ -765,84 +764,77 @@ def train_loop(config, config_inference, recorder, state=None):

example_batch = None

input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)

data_loader = DataLoader(config, mesh, data_iterator, recorder)
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

# Write train config params, num model params, and XLA flags to tensorboard
metric_logger.write_setup_info_to_tensorboard(state.params)

for step in np.arange(start_step, config.steps):
step_start_time = datetime.datetime.now()
prof.maybe_activate_profiler(step, state)

with jax.profiler.StepTraceAnnotation("train", step_num=step):
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
try:
example_batch = load_next_batch(data_iterator, example_batch, config)
example_batch = jax.lax.with_sharding_constraint(example_batch, input_data_shardings)
except Exception as e: # pylint: disable=broad-except
max_logging.log(f"load_next_batch failed, you may have run out of data. Error message: {e}")
break

check_example_batch(config, example_batch=example_batch)
# pylint: disable=not-callable
rng = jax.jit(jax.random.fold_in)(init_rng, step)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
rng, rng_gen = random.split(rng)
example_batch = p_generate_completions(example_batch, state.params, rng_gen)

# TODO: ensure this partitioning is correct
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, rng)

if checkpoint_manager is not None:
state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
checkpointing.print_save_message(step, config.async_checkpointing)

# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(step):
checkpoint_manager.wait_until_finished()
sys.exit()

if config.dump_hlo and step == start_step:
jax.block_until_ready(state) # Ensure compilation has finished.
max_utils.upload_dump(
config.dump_hlo_local_dir,
config.dump_hlo_gcs_dir,
module_name=config.dump_hlo_module_name,
delete_local_after=config.dump_hlo_delete_local_after,
all_host_upload=config.dump_hlo_upload_all,
)
try:
for step in np.arange(start_step, config.steps):
step_start_time = datetime.datetime.now()
prof.maybe_activate_profiler(step, state)

with jax.profiler.StepTraceAnnotation("train", step_num=step):
example_batch = data_loader.load_next_batch()
# pylint: disable=not-callable
rng = jax.jit(jax.random.fold_in)(init_rng, step)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
rng, rng_gen = random.split(rng)
example_batch = p_generate_completions(example_batch, state.params, rng_gen)

# TODO: ensure this partitioning is correct
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, rng)

if checkpoint_manager is not None:
state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
checkpointing.print_save_message(step, config.async_checkpointing)

# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(step):
checkpoint_manager.wait_until_finished()
sys.exit()

if config.dump_hlo and step == start_step:
jax.block_until_ready(state) # Ensure compilation has finished.
max_utils.upload_dump(
config.dump_hlo_local_dir,
config.dump_hlo_gcs_dir,
module_name=config.dump_hlo_module_name,
delete_local_after=config.dump_hlo_delete_local_after,
all_host_upload=config.dump_hlo_upload_all,
)

if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
assert eval_data_iterator
eval_step_count = 0
# pylint: disable=not-callable
for eval_batch in eval_data_iterator:
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, rng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
eval_step_count += 1
metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}")
prof.deactivate()
break

prof.maybe_deactivate_profiler(step, state)

if step == start_step:
max_utils.print_mem_stats("After params initialized")

jax.block_until_ready(state) # ensure training step is completed

step_time_delta = datetime.datetime.now() - step_start_time
metric_logger.record_train_metrics(metrics, step, step_time_delta)
if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
assert eval_data_iterator
eval_step_count = 0
# pylint: disable=not-callable
for eval_batch in eval_data_iterator:
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, rng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
eval_step_count += 1
metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
prof.deactivate()
raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.")

prof.maybe_deactivate_profiler(step, state)

if step == start_step:
max_utils.print_mem_stats("After params initialized")

jax.block_until_ready(state) # ensure training step is completed

step_time_delta = datetime.datetime.now() - step_start_time
metric_logger.record_train_metrics(metrics, step, step_time_delta)
except exceptions.StopTraining as e:
max_logging.log(f"Training stopped: {str(e)}")

if checkpoint_manager is not None:
if ((int(state.step) - 1) % config.checkpoint_period != 0) and (int(state.step) != 0):