Skip to content

Commit 2198d88

Browse files
committed
Hooks
Signed-off-by: Vladimir Suvorov <[email protected]>
1 parent 72cc320 commit 2198d88

File tree

3 files changed

+415
-80
lines changed

3 files changed

+415
-80
lines changed

src/MaxText/examples/grpo_llama3_demo.py

Lines changed: 16 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,7 @@
1717
This tutorial demonstrates training the Llama3.1 8B-IT model on
1818
the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO).
1919
GRPO can enhance your model's problem-solving skills on mathematical word problems,
20-
coding problems, etc.
21-
22-
GOODPUT MONITORING FEATURES:
23-
- Automatic goodput measurement and tracking
24-
- Badput breakdown analysis (non-productive time tracking)
25-
- Step time deviation monitoring
26-
- TensorBoard and Google Cloud Monitoring integration
27-
- Performance metrics upload to GCM
28-
- Real-time training efficiency monitoring
29-
"""
20+
coding problems, etc. """
3021

3122
# This tutorial demonstrates training the Llama3.1 8B-IT model on the GSM8K math
3223
# reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can
@@ -89,10 +80,6 @@
8980
from MaxText import pyconfig
9081
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
9182

92-
# MaxText goodput monitoring imports
93-
from MaxText.utils.goodput_utils import maybe_monitor_goodput, create_goodput_recorder, maybe_record_goodput, GoodputEvent
94-
from MaxText import max_logging
95-
9683
# This is for running the script in a colab or notebook environment.
9784
# import nest_asyncio
9885
# nest_asyncio.apply() # To fix "This event loop is already running" error in Colab
@@ -144,14 +131,6 @@
144131
# ====== Reproducibility ======
145132
SEED = 42
146133

147-
# ====== Goodput Monitoring ======
148-
# Enable goodput monitoring for performance tracking
149-
ENABLE_GOODPUT_RECORDING = True
150-
MONITOR_GOODPUT = True
151-
ENABLE_GCP_GOODPUT_METRICS = True
152-
ENABLE_GCP_STEP_DEVIATION_METRICS = True
153-
GOODPUT_UPLOAD_INTERVAL_SECONDS = 30
154-
155134

156135
# ====== GRPO ======
157136
# === Generation during GRPO training ===
@@ -929,30 +908,6 @@ def evaluate(
929908
# Let's set up all the configs first - checkpointing, metric logging and training.
930909
# We then train the model.
931910
def main():
932-
# Create a mock config object for goodput monitoring
933-
class MockConfig:
934-
935-
def __init__(self):
936-
self.monitor_goodput = MONITOR_GOODPUT
937-
self.enable_goodput_recording = ENABLE_GOODPUT_RECORDING
938-
self.enable_gcp_goodput_metrics = ENABLE_GCP_GOODPUT_METRICS
939-
self.enable_gcp_step_deviation_metrics = ENABLE_GCP_STEP_DEVIATION_METRICS
940-
self.goodput_upload_interval_seconds = GOODPUT_UPLOAD_INTERVAL_SECONDS
941-
self.run_name = "grpo_llama3_demo"
942-
self.tensorboard_dir = LOG_DIR
943-
self.enable_pathways_goodput = False
944-
self.monitor_step_time_deviation = True
945-
self.step_deviation_interval_seconds = 60
946-
self.report_performance_metric_for_gcp_monitoring = False
947-
948-
config = MockConfig()
949-
950-
# Initialize goodput monitoring
951-
maybe_monitor_goodput(config)
952-
recorder = create_goodput_recorder(config)
953-
954-
max_logging.log("GRPO training with goodput monitoring started")
955-
956911
# Checkpoint saving
957912
checkpointing_options = ocp.CheckpointManagerOptions(save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP)
958913

@@ -1057,55 +1012,37 @@ def __init__(self):
10571012

10581013
# ## Evaluate before training
10591014
#
1060-
max_logging.log("Starting pre-training evaluation...")
1061-
1062-
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
1063-
# pylint: disable=unbalanced-tuple-unpacking
1064-
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
1065-
test_dataset,
1066-
rl_cluster,
1067-
**GENERATION_CONFIGS["greedy"],
1068-
)
1069-
print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
10701015

1071-
max_logging.log(f"Pre-training evaluation completed: {accuracy}% accuracy")
1016+
# pylint: disable=unbalanced-tuple-unpacking
1017+
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
1018+
test_dataset,
1019+
rl_cluster,
1020+
**GENERATION_CONFIGS["greedy"],
1021+
)
1022+
print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
10721023

10731024
# ## Start training
10741025
#
1075-
max_logging.log("Starting GRPO training with goodput monitoring...")
10761026

10771027
jax.profiler.start_trace(PROFILE_DIR)
10781028
with mesh, nn_partitioning.axis_rules(config_policy.logical_axis_rules):
1079-
with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
1080-
max_logging.log("Training preparation phase recorded")
1081-
1082-
# Record the main training phase
1083-
with maybe_record_goodput(recorder, GoodputEvent.STEP):
1084-
grpo_trainer.train(DATASET)
1085-
1029+
grpo_trainer.train(DATASET)
10861030
jax.profiler.stop_trace()
10871031

1088-
max_logging.log("GRPO training completed")
1089-
10901032
print("HBM usage after training:")
10911033
show_hbm_usage()
10921034

10931035
# ## Evaluate
10941036
#
10951037
# Let's evaluate our model!
1096-
max_logging.log("Starting post-training evaluation...")
1097-
1098-
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
1099-
# pylint: disable=unbalanced-tuple-unpacking
1100-
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
1101-
test_dataset,
1102-
rl_cluster,
1103-
**GENERATION_CONFIGS["greedy"],
1104-
)
1105-
print(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
11061038

1107-
max_logging.log(f"Post-training evaluation completed: {accuracy}% accuracy")
1108-
max_logging.log("GRPO training with goodput monitoring finished successfully")
1039+
# pylint: disable=unbalanced-tuple-unpacking
1040+
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
1041+
test_dataset,
1042+
rl_cluster,
1043+
**GENERATION_CONFIGS["greedy"],
1044+
)
1045+
print(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
11091046

11101047

11111048
if __name__ == "__main__":

src/MaxText/experimental/rl/grpo_trainer.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
from MaxText.data_loader import DataLoader
8080
from MaxText.experimental.rl import grpo_input_pipeline
8181
from MaxText.experimental.rl import grpo_utils
82+
from MaxText.experimental.rl.hooks import GRPOTrainingHooks, GRPODataHooks
8283
from MaxText.globals import EPS
8384
from MaxText.metric_logger import MetricLogger
8485
from MaxText.train import get_first_step
@@ -708,9 +709,18 @@ def train_loop(config, config_inference, recorder, state=None):
708709
data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder)
709710
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
710711

712+
# Initialize GRPO training hooks
713+
training_hooks = GRPOTrainingHooks(
714+
config=config, mesh=mesh, learning_rate_schedule=learning_rate_schedule, goodput_recorder=recorder
715+
)
716+
data_hooks = GRPODataHooks(config=config, data_iterator=data_iterator, eval_data_iterator=eval_data_iterator)
717+
711718
# Write train config params, num model params, and XLA flags to tensorboard
712719
metric_logger.write_setup_info_to_tensorboard(state.params["params"])
713720

721+
# Call on_train_start hook
722+
training_hooks.on_train_start(state, start_step)
723+
714724
def generation_worker_fn(
715725
worker_inference_engine,
716726
worker_tokenizer_model,
@@ -765,6 +775,9 @@ def generation_worker_fn(
765775
inference_engine_lock = threading.Lock()
766776

767777
max_logging.log("Inference Rollout")
778+
# Track initial generation
779+
training_hooks.on_generation_start(start_step)
780+
gen_start_time = time.time()
768781
generate_completions(
769782
data_loader,
770783
inference_engine,
@@ -776,6 +789,10 @@ def generation_worker_fn(
776789
data_sharding,
777790
inference_engine_lock,
778791
)
792+
gen_time = time.time() - gen_start_time
793+
with data_buffer_lock:
794+
num_completions = sum(batch[config.train_data_columns].shape[0] for batch in data_buffer)
795+
training_hooks.on_generation_end(start_step, num_completions, gen_time)
779796

780797
required_batch_size = int(config.per_device_batch_size * config.num_generations * mesh.size)
781798
generation_thread = threading.Thread(
@@ -798,6 +815,9 @@ def generation_worker_fn(
798815
try:
799816
last_step_completion = datetime.datetime.now()
800817
for step in np.arange(start_step, config.steps):
818+
# Call on_train_step_start hook
819+
training_hooks.on_train_step_start(step)
820+
801821
prof.maybe_activate_profiler(step, state)
802822

803823
with jax.profiler.StepTraceAnnotation("train", step_num=step):
@@ -837,7 +857,11 @@ def generation_worker_fn(
837857
last_step_completion = datetime.datetime.now()
838858

839859
state_to_save = _split_grpo_state(state)[0]
840-
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step)
860+
checkpoint_saved = checkpointing.maybe_save_checkpoint(
861+
checkpoint_manager, state_to_save, config, data_iterator, step
862+
)
863+
if checkpoint_saved:
864+
training_hooks.on_checkpoint_save(step, config.checkpoint_dir)
841865

842866
if config.dump_hlo and step == start_step:
843867
jax.block_until_ready(state) # Ensure compilation has finished.
@@ -851,17 +875,23 @@ def generation_worker_fn(
851875

852876
if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
853877
assert eval_data_iterator
878+
# Call on_eval_start hook
879+
training_hooks.on_eval_start(step)
854880
eval_step_count = 0
855881
# pylint: disable=not-callable
856882
for eval_batch in eval_data_iterator:
857883
if 0 < config.eval_steps <= eval_step_count:
858884
break
859885
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
860886
eval_metrics = p_eval_step(state, eval_batch, rng)
887+
# Call on_eval_step hook
888+
training_hooks.on_eval_step(eval_metrics)
861889
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
862890
max_logging.log(f"Completed eval step {eval_step_count}")
863891
eval_step_count += 1
864892
metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
893+
# Call on_eval_end hook
894+
training_hooks.on_eval_end(step)
865895
if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
866896
prof.deactivate()
867897
raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.")
@@ -872,11 +902,17 @@ def generation_worker_fn(
872902
max_utils.print_mem_stats("After params initialized")
873903

874904
metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)
905+
906+
# Call on_train_step_end hook
907+
training_hooks.on_train_step_end(step, metrics, step_time_delta.total_seconds())
908+
875909
state_to_save = _split_grpo_state(state)[0]
876910
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
877911
except exceptions.StopTraining as e:
878912
max_logging.log(f"Training stopped: {str(e)}")
879913
finally:
914+
# Call on_train_end hook
915+
training_hooks.on_train_end(step)
880916
metric_logger.flush_metrics_and_cleanup()
881917
max_logging.log("Training loop finished or exited. Signaling generation worker to stop.")
882918
stop_event.set()

0 commit comments

Comments
 (0)