@@ -17,16 +17,7 @@
17 | 17 | This tutorial demonstrates training the Llama3.1 8B-IT model on
18 | 18 | the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO).
19 | 19 | GRPO can enhance your model's problem-solving skills on mathematical word problems,
20 | | - coding problems, etc. |
21 | | -
22 | | -GOODPUT MONITORING FEATURES: |
23 | | -- Automatic goodput measurement and tracking |
24 | | -- Badput breakdown analysis (non-productive time tracking) |
25 | | -- Step time deviation monitoring |
26 | | -- TensorBoard and Google Cloud Monitoring integration |
27 | | -- Performance metrics upload to GCM |
28 | | -- Real-time training efficiency monitoring |
29 | | -""" |
| 20 | + coding problems, etc. """ |
30 | 21 |
31 | 22 | # This tutorial demonstrates training the Llama3.1 8B-IT model on the GSM8K math |
32 | 23 | # reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can |
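The docstring above describes GRPO only at a high level. The "group relative" part refers to sampling several completions per prompt, scoring each with a reward function, and normalizing every reward against the other completions in its group before the policy update. A minimal illustration of that normalization, assuming nothing beyond NumPy (this is not code from the tutorial itself):

import numpy as np

def group_relative_advantages(rewards, eps=1e-6):
  """Center and scale each completion's reward within its prompt group."""
  rewards = np.asarray(rewards, dtype=np.float32)  # shape: [num_prompts, completions_per_prompt]
  mean = rewards.mean(axis=-1, keepdims=True)
  std = rewards.std(axis=-1, keepdims=True)
  return (rewards - mean) / (std + eps)

# Two prompts with four sampled answers each; the rewards are made up for illustration.
print(group_relative_advantages([[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]))

Completions that beat their group's mean reward receive positive advantages; this within-group normalization is what lets GRPO skip a learned value baseline.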
@@ -89,10 +80,6 @@
89 | 80 | from MaxText import pyconfig |
90 | 81 | from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter |
91 | 82 |
92 | | -# MaxText goodput monitoring imports |
93 | | -from MaxText.utils.goodput_utils import maybe_monitor_goodput, create_goodput_recorder, maybe_record_goodput, GoodputEvent |
94 | | -from MaxText import max_logging |
95 | | - |
96 | 83 | # This is for running the script in a colab or notebook environment. |
97 | 84 | # import nest_asyncio |
98 | 85 | # nest_asyncio.apply() # To fix "This event loop is already running" error in Colab |
@@ -144,14 +131,6 @@
144 | 131 | # ====== Reproducibility ====== |
145 | 132 | SEED = 42 |
146 | 133 |
147 | | -# ====== Goodput Monitoring ====== |
148 | | -# Enable goodput monitoring for performance tracking |
149 | | -ENABLE_GOODPUT_RECORDING = True |
150 | | -MONITOR_GOODPUT = True |
151 | | -ENABLE_GCP_GOODPUT_METRICS = True |
152 | | -ENABLE_GCP_STEP_DEVIATION_METRICS = True |
153 | | -GOODPUT_UPLOAD_INTERVAL_SECONDS = 30 |
154 | | - |
155 | 134 |
156 | 135 | # ====== GRPO ====== |
157 | 136 | # === Generation during GRPO training === |
@@ -929,30 +908,6 @@ def evaluate( |
929 | 908 | # Let's set up all the configs first - checkpointing, metric logging and training. |
930 | 909 | # We then train the model. |
931 | 910 | def main(): |
932 | | - # Create a mock config object for goodput monitoring |
933 | | - class MockConfig: |
934 | | - |
935 | | - def __init__(self): |
936 | | - self.monitor_goodput = MONITOR_GOODPUT |
937 | | - self.enable_goodput_recording = ENABLE_GOODPUT_RECORDING |
938 | | - self.enable_gcp_goodput_metrics = ENABLE_GCP_GOODPUT_METRICS |
939 | | - self.enable_gcp_step_deviation_metrics = ENABLE_GCP_STEP_DEVIATION_METRICS |
940 | | - self.goodput_upload_interval_seconds = GOODPUT_UPLOAD_INTERVAL_SECONDS |
941 | | - self.run_name = "grpo_llama3_demo" |
942 | | - self.tensorboard_dir = LOG_DIR |
943 | | - self.enable_pathways_goodput = False |
944 | | - self.monitor_step_time_deviation = True |
945 | | - self.step_deviation_interval_seconds = 60 |
946 | | - self.report_performance_metric_for_gcp_monitoring = False |
947 | | - |
948 | | - config = MockConfig() |
949 | | - |
950 | | - # Initialize goodput monitoring |
951 | | - maybe_monitor_goodput(config) |
952 | | - recorder = create_goodput_recorder(config) |
953 | | - |
954 | | - max_logging.log("GRPO training with goodput monitoring started") |
955 | | - |
956 | 911 | # Ckpt saving |
957 | 912 | checkpointing_options = ocp.CheckpointManagerOptions(save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP) |
958 | 913 |
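The retained checkpointing line maps onto Orbax's checkpoint manager. A minimal sketch of how such options are typically consumed, with placeholder values and a hypothetical directory standing in for the tutorial's SAVE_INTERVAL_STEPS, MAX_TO_KEEP, and checkpoint path:

import orbax.checkpoint as ocp

# Placeholder values; the tutorial passes SAVE_INTERVAL_STEPS and MAX_TO_KEEP instead.
options = ocp.CheckpointManagerOptions(save_interval_steps=100, max_to_keep=4)

# A manager built from these options writes only on steps that hit the interval
# and retains at most max_to_keep recent checkpoints (directory is hypothetical).
manager = ocp.CheckpointManager("/tmp/grpo_ckpts", options=options)
print(manager.should_save(200))

save_interval_steps throttles how often a checkpoint is written, while max_to_keep bounds how many older checkpoints survive cleanup.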
@@ -1057,55 +1012,37 @@ def __init__(self): |
1057 | 1012 |
1058 | 1013 | # ## Evaluate before training |
1059 | 1014 | # |
1060 | | - max_logging.log("Starting pre-training evaluation...") |
1061 | | - |
1062 | | - with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING): |
1063 | | - # pylint: disable=unbalanced-tuple-unpacking |
1064 | | - (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( |
1065 | | - test_dataset, |
1066 | | - rl_cluster, |
1067 | | - **GENERATION_CONFIGS["greedy"], |
1068 | | - ) |
1069 | | - print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") |
1070 | 1015 |
1071 | | - max_logging.log(f"Pre-training evaluation completed: {accuracy}% accuracy") |
| 1016 | + # pylint: disable=unbalanced-tuple-unpacking |
| 1017 | + (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( |
| 1018 | + test_dataset, |
| 1019 | + rl_cluster, |
| 1020 | + **GENERATION_CONFIGS["greedy"], |
| 1021 | + ) |
| 1022 | + print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") |
1072 | 1023 |
1073 | 1024 | # ## Start training |
1074 | 1025 | # |
1075 | | - max_logging.log("Starting GRPO training with goodput monitoring...") |
1076 | 1026 |
1077 | 1027 | jax.profiler.start_trace(PROFILE_DIR) |
1078 | 1028 | with mesh, nn_partitioning.axis_rules(config_policy.logical_axis_rules): |
1079 | | - with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): |
1080 | | - max_logging.log("Training preparation phase recorded") |
1081 | | - |
1082 | | - # Record the main training phase |
1083 | | - with maybe_record_goodput(recorder, GoodputEvent.STEP): |
1084 | | - grpo_trainer.train(DATASET) |
1085 | | - |
| 1029 | + grpo_trainer.train(DATASET) |
1086 | 1030 | jax.profiler.stop_trace() |
1087 | 1031 |
1088 | | - max_logging.log("GRPO training completed") |
1089 | | - |
1090 | 1032 | print("HBM usage after training:") |
1091 | 1033 | show_hbm_usage() |
1092 | 1034 |
1093 | 1035 | # ## Evaluate |
1094 | 1036 | # |
1095 | 1037 | # Let's evaluate our model! |
1096 | | - max_logging.log("Starting post-training evaluation...") |
1097 | | - |
1098 | | - with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING): |
1099 | | - # pylint: disable=unbalanced-tuple-unpacking |
1100 | | - (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( |
1101 | | - test_dataset, |
1102 | | - rl_cluster, |
1103 | | - **GENERATION_CONFIGS["greedy"], |
1104 | | - ) |
1105 | | - print(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") |
1106 | 1038 |
1107 | | - max_logging.log(f"Post-training evaluation completed: {accuracy}% accuracy") |
1108 | | - max_logging.log("GRPO training with goodput monitoring finished successfully") |
| 1039 | + # pylint: disable=unbalanced-tuple-unpacking |
| 1040 | + (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( |
| 1041 | + test_dataset, |
| 1042 | + rl_cluster, |
| 1043 | + **GENERATION_CONFIGS["greedy"], |
| 1044 | + ) |
| 1045 | + print(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") |
1109 | 1046 |
1110 | 1047 |
1111 | 1048 | if __name__ == "__main__": |
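For readers wondering what the printed metrics measure: GSM8K references end with a final numeric answer, so evaluators and reward functions for this benchmark usually compare extracted final numbers. The helper below is a hypothetical illustration of that idea, not the tutorial's evaluate implementation:

import re

def extract_final_number(text):
  """Return the last number in a string, e.g. the value after GSM8K's '#### ' marker."""
  matches = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
  return matches[-1] if matches else None

def is_correct(response, reference):
  """Hypothetical exact-match check on the extracted final numbers."""
  return extract_final_number(response) == extract_final_number(reference)

print(is_correct("The robe costs $3 in total, so the answer is 3.", "#### 3"))  # True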