5 changes: 3 additions & 2 deletions src/MaxText/experimental/rl/grpo_input_pipeline.py
@@ -36,6 +36,7 @@

from MaxText.input_pipeline import input_pipeline_interface
from MaxText.input_pipeline import _input_pipeline_utils
+from MaxText import multihost_dataloading


class SingleHostDataLoader:
@@ -194,8 +195,8 @@ def lists2array(x):
read_options=grain.ReadOptions(num_threads=num_threads, prefetch_buffer_size=128),
)

-# single_host_gen = SingleHostDataLoader(dataloader, global_mesh)
-return iter(dataloader)
+# Return multi-host data iterator for proper multi-host data loading
+return multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)


def make_hf_train_iterator(
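The net effect of this file's change: instead of returning a plain single-host iterator (`iter(dataloader)`), the pipeline now returns a `multihost_dataloading.MultiHostDataLoadIterator`, so each host contributes its shard of every global batch. Below is a minimal, hedged sketch of how such an iterator would be built and consumed; the helper name `build_grpo_train_iterator`, the `next()` usage, and the assumption that the object behaves like a standard Python iterator are illustrative, inferred only from the fact that it replaces `iter(dataloader)` above.

# Sketch only: assumes MultiHostDataLoadIterator can be consumed like the plain
# iterator it replaces in the diff above (i.e. via next()).
from MaxText import multihost_dataloading

def build_grpo_train_iterator(dataloader, global_mesh):
  """Hypothetical helper mirroring the new return statement: wrap the per-host
  grain dataloader so iteration yields batches assembled across all hosts."""
  return multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)

# Illustrative usage (not part of the PR):
# train_iter = build_grpo_train_iterator(dataloader, global_mesh)
# batch = next(train_iter)  # one globally assembled batch per training step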
44 changes: 42 additions & 2 deletions src/MaxText/experimental/rl/grpo_trainer.py
@@ -76,9 +76,9 @@
from MaxText.checkpointing import CheckpointManager
from MaxText.utils import gcs_utils
from MaxText.inference import offline_engine
-from MaxText.data_loader import DataLoader
from MaxText.experimental.rl import grpo_input_pipeline
from MaxText.experimental.rl import grpo_utils
+from MaxText.experimental.rl.hooks import GRPOTrainingHooks, GRPODataHooks
from MaxText.globals import EPS
from MaxText.metric_logger import MetricLogger
from MaxText.train import get_first_step
@@ -705,12 +705,26 @@ def train_loop(config, config_inference, recorder, state=None):

start_step = get_first_step(state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
-data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder)
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

+# Initialize GRPO training hooks
+training_hooks = GRPOTrainingHooks(
+  config=config, mesh=mesh, learning_rate_schedule=learning_rate_schedule, goodput_recorder=recorder
+)
+
+# Initialize GRPO data hooks with multi-host data pipeline
+# This replaces the old DataLoader with improved multi-host data loading
+data_hooks = GRPODataHooks(config=config, mesh=mesh, goodput_recorder=recorder)
+
+# Use the data_hooks' train_data_loader for loading prompts
+data_loader = data_hooks.train_data_loader
+
# Write train config params, num model params, and XLA flags to tensorboard
metric_logger.write_setup_info_to_tensorboard(state.params["params"])

+# Call on_train_start hook
+training_hooks.on_train_start(state, start_step)
+
def generation_worker_fn(
worker_inference_engine,
worker_tokenizer_model,
@@ -765,6 +779,9 @@ def generation_worker_fn(
inference_engine_lock = threading.Lock()

max_logging.log("Inference Rollout")
+# Track initial generation
+training_hooks.on_generation_start(start_step)
+gen_start_time = time.time()
generate_completions(
data_loader,
inference_engine,
@@ -776,6 +793,10 @@ def generation_worker_fn(
data_sharding,
inference_engine_lock,
)
+gen_time = time.time() - gen_start_time
+with data_buffer_lock:
+  num_completions = sum(batch[config.train_data_columns].shape[0] for batch in data_buffer)
+training_hooks.on_generation_end(start_step, num_completions, gen_time)

required_batch_size = int(config.per_device_batch_size * config.num_generations * mesh.size)
generation_thread = threading.Thread(
@@ -797,7 +818,11 @@ def generation_worker_fn(

try:
last_step_completion = datetime.datetime.now()
+step = start_step  # Initialize step variable
for step in np.arange(start_step, config.steps):
+# Call on_train_step_start hook
+training_hooks.on_train_step_start(step)
+
prof.maybe_activate_profiler(step, state)

with jax.profiler.StepTraceAnnotation("train", step_num=step):
@@ -838,6 +863,9 @@ def generation_worker_fn(

state_to_save = _split_grpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step)
+# Note: maybe_save_checkpoint doesn't return a value, so we check if it's a checkpoint step
+if step % config.checkpoint_period == 0:
+  training_hooks.on_checkpoint_save(step, config.checkpoint_dir)

if config.dump_hlo and step == start_step:
jax.block_until_ready(state) # Ensure compilation has finished.
@@ -851,17 +879,23 @@ def generation_worker_fn(

if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
assert eval_data_iterator
+# Call on_eval_start hook
+training_hooks.on_eval_start(step)
eval_step_count = 0
# pylint: disable=not-callable
for eval_batch in eval_data_iterator:
if 0 < config.eval_steps <= eval_step_count:
break
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, rng)
+# Call on_eval_step hook
+training_hooks.on_eval_step(eval_metrics)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
eval_step_count += 1
metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
+# Call on_eval_end hook
+training_hooks.on_eval_end(step)
if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
prof.deactivate()
raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.")
@@ -872,11 +906,17 @@ def generation_worker_fn(
max_utils.print_mem_stats("After params initialized")

metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)
+
+# Call on_train_step_end hook
+training_hooks.on_train_step_end(step, metrics, step_time_delta.total_seconds())
+
state_to_save = _split_grpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
except exceptions.StopTraining as e:
max_logging.log(f"Training stopped: {str(e)}")
finally:
+# Call on_train_end hook
+training_hooks.on_train_end(step)
metric_logger.flush_metrics_and_cleanup()
max_logging.log("Training loop finished or exited. Signaling generation worker to stop.")
stop_event.set()
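For readers without the new MaxText.experimental.rl.hooks file in front of them, the sketch below reconstructs the interface the trainer now expects, using only the call sites visible in this diff: the constructor keywords, the on_* methods and their argument order, and the train_data_loader attribute on the data hooks. The *Sketch class names and the empty method bodies are placeholders; the real module may define richer signatures and behavior.

# Interface sketch inferred solely from the call sites in the diff above; the
# actual MaxText.experimental.rl.hooks module may differ.
from dataclasses import dataclass
from typing import Any


@dataclass
class GRPOTrainingHooksSketch:
  config: Any
  mesh: Any
  learning_rate_schedule: Any
  goodput_recorder: Any = None

  def on_train_start(self, state, start_step): ...
  def on_train_step_start(self, step): ...
  def on_train_step_end(self, step, metrics, step_seconds): ...
  def on_generation_start(self, step): ...
  def on_generation_end(self, step, num_completions, gen_seconds): ...
  def on_eval_start(self, step): ...
  def on_eval_step(self, eval_metrics): ...
  def on_eval_end(self, step): ...
  def on_checkpoint_save(self, step, checkpoint_dir): ...
  def on_train_end(self, step): ...


@dataclass
class GRPODataHooksSketch:
  config: Any
  mesh: Any
  goodput_recorder: Any = None
  # The trainer reads prompts through this attribute:
  # data_loader = data_hooks.train_data_loader
  train_data_loader: Any = None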