"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# pylint: disable=bare-except, consider-using-generator
"""Utils that are only interesting to MaxText. To break a circular dependency, all util functions that rely on MaxText.checkpointing live here."""

import functools

import jax

from flax.linen import partitioning as nn_partitioning

import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager

from MaxText import checkpointing
from MaxText import max_logging
from MaxText import max_utils
from MaxText import maxtext_utils


def setup_initial_state(
    model,
    data_iterator,
    tx,
    config,
    rng,
    mesh,
    checkpoint_manager,
    is_training=True,
):
  """Initialize the model and optimizer state, optionally restoring from a checkpoint.

  Args:
    model: the flax model to initialize
    data_iterator: the data iterator whose state may be restored along with the train state
    tx: the optax.GradientTransformation
    config: config object
    rng: a jax PRNG key
    mesh: the device mesh (jax.sharding.Mesh)
    checkpoint_manager: an Orbax checkpointing.CheckpointManager object
    is_training: True to initialize training state, False for decode state

  Returns:
    state: the initialized train state
    state_mesh_annotations: the mesh annotations for the train state
    state_mesh_shardings: the shardings for the train state
    data_iterator: the data iterator, with its state restored if a checkpoint provided one
  """

  unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
      model, tx, config, rng, mesh, is_training
  )

  # Initialization: restore from a checkpoint if possible, otherwise initialize from scratch.
  with nn_partitioning.axis_rules(config.logical_axis_rules):
    restored, raw_params = checkpointing.load_state_if_possible(
        checkpoint_manager,
        data_iterator,
        config.load_parameters_path,
        config.load_full_state_path,
        config.checkpoint_storage_concurrent_gb,
        unboxed_abstract_state,
        config.enable_single_replica_ckpt_restoring,
        config.dataset_type,
        use_ocdbt=config.checkpoint_storage_use_ocdbt,
        use_zarr3=config.checkpoint_storage_use_zarr3,
    )

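  # As handled below: the emergency checkpoint managers return the train state
  # directly, while the regular manager returns a dict holding the state under
  # "items" (and optionally the data iterator under "iter"); `raw_params` is
  # set when only parameters, not a full state, were restored.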
  if restored:
    if isinstance(
        checkpoint_manager,
        (
            emergency_checkpoint_manager.CheckpointManager,
            emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager,
        ),
    ):
      state = restored
    else:
      if "iter" in restored and restored["iter"] is not None:
        data_iterator.local_iterator = restored["iter"]
      state = restored["items"]
  else:
    init_state_partial = functools.partial(maxtext_utils.init_initial_state, model, tx, config, is_training)
    init_state_partial.__name__ = "initialize_state"
    # pylint: disable=not-callable
    state = jax.jit(
        init_state_partial,
        in_shardings=None,
        out_shardings=state_mesh_shardings,
    )(rng)
    if raw_params:  # If we loaded a partial state, we need to merge it.
      state = state.replace(params=raw_params)

  state = max_utils.unbox_logicallypartioned(state)

  return state, state_mesh_annotations, state_mesh_shardings, data_iterator

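# A minimal usage sketch for setup_initial_state (the `model`, `tx`, `config`,
# `mesh`, and `checkpoint_manager` objects are assumed to be built by the
# caller, e.g. in a training script):
#
#   state, state_mesh_annotations, state_mesh_shardings, data_iterator = setup_initial_state(
#       model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training=True
#   )
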
def setup_decode_state(model, config, rng, mesh, checkpoint_manager):
  """Set up a decode state, loading params from a checkpoint when one is configured.

  Args:
    model: the flax model to initialize
    config: config object
    rng: a jax PRNG key
    mesh: the device mesh (jax.sharding.Mesh)
    checkpoint_manager: Checkpoint manager

  Returns:
    state: state with decode params loaded from the checkpoint
    state_mesh_annotations: the mesh annotations for the state
  """
  if not config.load_parameters_path:
    # Generate random params.
    max_logging.log("No decode checkpoint specified - generating random weights.")
    state, state_mesh_annotations, _, _ = setup_initial_state(
        model, None, None, config, rng, mesh, checkpoint_manager, False
    )
  else:
    # Load params from checkpoint.
    max_logging.log(f"Loading decode params from {config.load_parameters_path}")
    unboxed_abstract_state, state_mesh_annotations, _ = maxtext_utils.get_abstract_state(
        model, None, config, rng, mesh, False
    )
    with nn_partitioning.axis_rules(config.logical_axis_rules):
      params = checkpointing.load_params_from_path(
          config.load_parameters_path,
          unboxed_abstract_state.params,
          config.checkpoint_storage_concurrent_gb,
          config.checkpoint_storage_use_ocdbt,
          config.checkpoint_storage_use_zarr3,
      )
    state = maxtext_utils.init_decode_state(None, params)

  state = max_utils.unbox_logicallypartioned(state)
  return state, state_mesh_annotations


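# Decode-state usage sketch (assumes `config.load_parameters_path` points at a
# parameters-only checkpoint; all objects are built by the caller):
#
#   state, state_mesh_annotations = setup_decode_state(model, config, rng, mesh, checkpoint_manager)
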
def save_quantized_checkpoint_if_configured(config, params):
  """Save quantized params to save_quantized_params_path, if one is configured."""
  assert config.quantization, "quantization must be configured"
  if config.save_quantized_params_path:
    checkpointing.save_params_to_path(config.save_quantized_params_path, params)
  else:
    max_logging.log("Skipping saving quantized checkpoint as save_quantized_params_path is null.")


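# A sketch of a typical call site (the quantized `params` are assumed to be
# produced by the caller's quantization pass; `quantize_params` is a
# hypothetical helper, not part of this module):
#
#   params = quantize_params(model, state)
#   save_quantized_checkpoint_if_configured(config, params)
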
def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager):
  """Convenience wrapper around setup_initial_state with is_training=True."""
  is_training = True
  return setup_initial_state(
      model,
      data_iterator,
      tx,
      config,
      rng,
      mesh,
      checkpoint_manager,
      is_training,
  )