@@ -50,24 +50,23 @@
 import transformers
 
 from MaxText import checkpointing
+from MaxText import exceptions
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText import maxengine
 from MaxText import maxtext_utils
 from MaxText import profiler
 from MaxText import pyconfig
 from MaxText.common_types import Array
+from MaxText.data_loader import DataLoader
 from MaxText.experimental.rl import grpo_input_pipeline
-from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
 from MaxText.globals import EPS
 from MaxText.layers import models
 from MaxText.metric_logger import MetricLogger
 from MaxText.train import (
     validate_train_config,
     get_first_step,
-    load_next_batch,
     save_checkpoint,
-    check_example_batch,
     setup_mesh_and_model,
 )
 from MaxText.utils.goodput_utils import (
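
Note: the new `MaxText.exceptions` module imported above is not shown in this diff. Judging only from how it is used in `train_loop` below (raised to stop training early and caught around the loop), a minimal sketch could be as simple as the following; the class body is an assumption, not the actual MaxText implementation.

```python
# Hypothetical sketch of the relevant part of MaxText/exceptions.py (assumed, not shown in this diff).
class StopTraining(Exception):
  """Signals that the training loop should end gracefully, e.g. when the target eval loss is reached."""
```
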
@@ -765,84 +764,77 @@ def train_loop(config, config_inference, recorder, state=None):
 
   example_batch = None
 
-  input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
-
+  data_loader = DataLoader(config, mesh, data_iterator, recorder)
   metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
 
   # Write train config params, num model params, and XLA flags to tensorboard
   metric_logger.write_setup_info_to_tensorboard(state.params)
 
-  for step in np.arange(start_step, config.steps):
-    step_start_time = datetime.datetime.now()
-    prof.maybe_activate_profiler(step, state)
-
-    with jax.profiler.StepTraceAnnotation("train", step_num=step):
-      with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
-        try:
-          example_batch = load_next_batch(data_iterator, example_batch, config)
-          example_batch = jax.lax.with_sharding_constraint(example_batch, input_data_shardings)
-        except Exception as e:  # pylint: disable=broad-except
-          max_logging.log(f"load_next_batch failed, you may have run out of data. Error message: {e}")
-          break
-
-      check_example_batch(config, example_batch=example_batch)
-      # pylint: disable=not-callable
-      rng = jax.jit(jax.random.fold_in)(init_rng, step)
-      with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
-        rng, rng_gen = random.split(rng)
-        example_batch = p_generate_completions(example_batch, state.params, rng_gen)
-
-      # TODO: ensure this partitioning is correct
-      with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-        state, metrics = p_train_step(state, example_batch, rng)
-
-    if checkpoint_manager is not None:
-      state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
-      if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
-        checkpointing.print_save_message(step, config.async_checkpointing)
-
-      # Upon preemption, exit when and only when all ongoing saves are complete.
-      if checkpoint_manager.reached_preemption(step):
-        checkpoint_manager.wait_until_finished()
-        sys.exit()
-
-    if config.dump_hlo and step == start_step:
-      jax.block_until_ready(state)  # Ensure compilation has finished.
-      max_utils.upload_dump(
-          config.dump_hlo_local_dir,
-          config.dump_hlo_gcs_dir,
-          module_name=config.dump_hlo_module_name,
-          delete_local_after=config.dump_hlo_delete_local_after,
-          all_host_upload=config.dump_hlo_upload_all,
-      )
+  try:
+    for step in np.arange(start_step, config.steps):
+      step_start_time = datetime.datetime.now()
+      prof.maybe_activate_profiler(step, state)
+
+      with jax.profiler.StepTraceAnnotation("train", step_num=step):
+        example_batch = data_loader.load_next_batch()
+        # pylint: disable=not-callable
+        rng = jax.jit(jax.random.fold_in)(init_rng, step)
+        with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
+          rng, rng_gen = random.split(rng)
+          example_batch = p_generate_completions(example_batch, state.params, rng_gen)
+
+        # TODO: ensure this partitioning is correct
+        with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+          state, metrics = p_train_step(state, example_batch, rng)
+
+      if checkpoint_manager is not None:
+        state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
+        if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
+          checkpointing.print_save_message(step, config.async_checkpointing)
+
+        # Upon preemption, exit when and only when all ongoing saves are complete.
+        if checkpoint_manager.reached_preemption(step):
+          checkpoint_manager.wait_until_finished()
+          sys.exit()
+
+      if config.dump_hlo and step == start_step:
+        jax.block_until_ready(state)  # Ensure compilation has finished.
+        max_utils.upload_dump(
+            config.dump_hlo_local_dir,
+            config.dump_hlo_gcs_dir,
+            module_name=config.dump_hlo_module_name,
+            delete_local_after=config.dump_hlo_delete_local_after,
+            all_host_upload=config.dump_hlo_upload_all,
+        )
 
-    if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
-      assert eval_data_iterator
-      eval_step_count = 0
-      # pylint: disable=not-callable
-      for eval_batch in eval_data_iterator:
-        if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
-          break
-        with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-          eval_metrics = p_eval_step(state, eval_batch, rng)
-        metric_logger.record_eval_metrics(step, metrics=eval_metrics)
-        max_logging.log(f"Completed eval step {eval_step_count}")
-        eval_step_count += 1
-      metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
-      if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
-        max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}")
-        prof.deactivate()
-        break
-
-    prof.maybe_deactivate_profiler(step, state)
-
-    if step == start_step:
-      max_utils.print_mem_stats("After params initialized")
-
-    jax.block_until_ready(state)  # ensure training step is completed
-
-    step_time_delta = datetime.datetime.now() - step_start_time
-    metric_logger.record_train_metrics(metrics, step, step_time_delta)
+      if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
+        assert eval_data_iterator
+        eval_step_count = 0
+        # pylint: disable=not-callable
+        for eval_batch in eval_data_iterator:
+          if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
+            break
+          with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+            eval_metrics = p_eval_step(state, eval_batch, rng)
+          metric_logger.record_eval_metrics(step, metrics=eval_metrics)
+          max_logging.log(f"Completed eval step {eval_step_count}")
+          eval_step_count += 1
+        metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
+        if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
+          prof.deactivate()
+          raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.")
+
+      prof.maybe_deactivate_profiler(step, state)
+
+      if step == start_step:
+        max_utils.print_mem_stats("After params initialized")
+
+      jax.block_until_ready(state)  # ensure training step is completed
+
+      step_time_delta = datetime.datetime.now() - step_start_time
+      metric_logger.record_train_metrics(metrics, step, step_time_delta)
+  except exceptions.StopTraining as e:
+    max_logging.log(f"Training stopped: {str(e)}")
 
   if checkpoint_manager is not None:
     if ((int(state.step) - 1) % config.checkpoint_period != 0) and (int(state.step) != 0):
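
For context on the `DataLoader` change: the in-loop data handling that this patch deletes (goodput `DATA_LOADING` recording, fetching the next batch, applying the input sharding constraint, checking the batch, and breaking out when data runs out) is now expected to live behind `data_loader.load_next_batch()`. The sketch below reconstructs that behavior from the removed lines; the class body, attribute names, and the use of `next()` are assumptions, not the actual `MaxText.data_loader` implementation.

```python
# Hypothetical sketch of MaxText.data_loader.DataLoader, inferred from the code removed above.
import jax

from MaxText import exceptions, maxtext_utils
from MaxText.utils.goodput_utils import GoodputEvent, maybe_record_goodput


class DataLoader:
  """Wraps the training data iterator: goodput recording, input sharding, end-of-data handling."""

  def __init__(self, config, mesh, data_iterator, goodput_recorder):
    self.config = config
    self.data_iterator = data_iterator
    self.goodput_recorder = goodput_recorder
    self.input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)

  def load_next_batch(self):
    """Fetch and shard the next batch; raise StopTraining once the iterator is exhausted."""
    with maybe_record_goodput(self.goodput_recorder, GoodputEvent.DATA_LOADING):
      try:
        example_batch = next(self.data_iterator)
        example_batch = jax.lax.with_sharding_constraint(example_batch, self.input_data_shardings)
      except Exception as e:  # pylint: disable=broad-except
        raise exceptions.StopTraining(f"You may have run out of data. Error message: {e}") from e
    return example_batch
```

With this split, `train_loop` no longer handles data exhaustion itself: both end-of-data and the early-stop branch funnel into the single `except exceptions.StopTraining` handler shown in the hunk above.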