
Commit f3a0a49

Moving the checkpointing util functions to a separate file to avoid a circular dependency.
1 parent 7ffd13d commit f3a0a49

19 files changed (+263 -224 lines)
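For context on the commit message: maxtext_utils held state-setup helpers that call into MaxText.checkpointing, and something in the checkpointing import chain presumably needed maxtext_utils in turn, so neither module could finish importing first. Below is a minimal, self-contained sketch of that failure mode and of the extraction pattern this commit applies; the file and function names are illustrative stand-ins, not the actual MaxText modules.

# utils.py (before) -- stands in for maxtext_utils
import checkpointing  # edge 1 of the cycle: utils -> checkpointing

def get_abstract_state(config):
    # Pure helper with no checkpointing dependency; it can stay in utils.
    return {"step": 0, "config": config}

def setup_training_state(config):
    # This helper is what forced utils to import checkpointing.
    return checkpointing.load_state_if_possible(get_abstract_state(config))

# checkpointing.py (before) -- stands in for the other side of the cycle
import utils  # edge 2 closes the cycle: at import time one module sees the
              # other only partially initialized and raises ImportError

def load_state_if_possible(abstract_state):
    return abstract_state

# state_init_utils.py (after) -- the fix: a new leaf module owns the helpers
# that need both sides, so utils can drop its checkpointing import entirely.
import checkpointing
import utils

def setup_training_state(config):
    return checkpointing.load_state_if_possible(utils.get_abstract_state(config))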

MaxText/convert_gpt3_ckpt_from_paxml.py

Lines changed: 2 additions & 1 deletion

@@ -52,6 +52,7 @@
 import jax
 import gc
 from MaxText import max_logging
+from MaxText import maxtext_state_initialization_utils
 from psutil import Process
 from MaxText.train import save_checkpoint
 import argparse
@@ -109,7 +110,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
       cfg.checkpoint_period,
   )

-  state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
+  state, _, _, _ = maxtext_state_initialization_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
   max_logging.log("start")
   check_memory()

MaxText/elastic_train.py

Lines changed: 2 additions & 1 deletion

@@ -61,6 +61,7 @@
 from MaxText import checkpointing
 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText import maxtext_state_initialization_utils
 from MaxText import max_logging
 from MaxText import profiler
 from MaxText import pyconfig
@@ -133,7 +134,7 @@ def elastic_handler(
   max_logging.log(f"Deleting checkpoint from step {latest_step} since we are rewinding to step {step}.")
   checkpoint_manager.delete(latest_step)

-  state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
+  state, _, state_mesh_shardings, data_iterator = maxtext_state_initialization_utils.setup_training_state(
       model,
       data_iterator,
       tx,
MaxText/experimental/rl/grpo_trainer.py

Lines changed: 2 additions & 1 deletion

@@ -42,6 +42,7 @@
 from MaxText import checkpointing
 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText import maxtext_state_initialization_utils
 from MaxText import max_logging
 from MaxText import profiler
 from MaxText import pyconfig
@@ -649,7 +650,7 @@ def setup_train_loop(config):
   record_goodput(recorder, config, recorder.record_tpu_init_end_time if recorder else None)
   record_goodput(recorder, config, recorder.record_training_preparation_start_time if recorder else None)
   data_iterator, eval_data_iterator = grpo_input_pipeline.create_data_iterator(config, mesh)
-  state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
+  state, _, state_mesh_shardings, data_iterator = maxtext_state_initialization_utils.setup_training_state(
       model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
   )

MaxText/generate_param_only_checkpoint.py

Lines changed: 2 additions & 1 deletion

@@ -28,6 +28,7 @@
 import jax
 from MaxText import max_logging
 from MaxText import max_utils
+from MaxText import maxtext_state_initialization_utils
 from MaxText import maxtext_utils
 from MaxText import optimizers
 from MaxText import pyconfig
@@ -97,7 +98,7 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
   rng = random.PRNGKey(0)
   learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
   tx = optimizers.get_optimizer(config, learning_rate_schedule)
-  state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state(
+  state, state_mesh_notations, _, _ = maxtext_state_initialization_utils.setup_training_state(
       model, None, tx, config, rng, mesh, checkpoint_manager
   )
   num_params = max_utils.calculate_num_params_from_pytree(state.params)

MaxText/maxengine.py

Lines changed: 3 additions & 2 deletions

@@ -44,6 +44,7 @@

 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText import maxtext_state_initialization_utils
 from MaxText import inference_utils
 from MaxText import pyconfig

@@ -234,7 +235,7 @@ def load_params(self, *args, params=None, rng: Optional[PRNGKeyType] = None, **k
       state = maxtext_utils.init_decode_state(None, params)
       state = max_utils.unbox_logicallypartioned(state)
     else:
-      state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.model, self.config, rng1, self._mesh, None)
+      state, self.state_mesh_annotations = maxtext_state_initialization_utils.setup_decode_state(self.model, self.config, rng1, self._mesh, None)
     # pylint: disable=isinstance-second-argument-not-valid-type
     self.abstract_params = jax.tree_util.tree_map(
         lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
@@ -341,7 +342,7 @@ def model_apply(_p, _rng):
         lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding),
         params,
     )
-    maxtext_utils.save_quantized_checkpoint_if_configured(self.config, params)
+    maxtext_state_initialization_utils.save_quantized_checkpoint_if_configured(self.config, params)
     self.model.quant.quant_mode = quantizations.get_quant_mode("serve")
     return params

MaxText/maxtext_state_initialization_utils.py

Lines changed: 166 additions & 0 deletions

@@ -0,0 +1,166 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# pylint: disable=bare-except, consider-using-generator
+"""Utils that are only interesting to MaxText. To break the circular dependency, all the util functions relying on MaxText.checkpointing are here."""
+
+import functools
+
+import jax
+
+from flax.linen import partitioning as nn_partitioning
+
+import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
+import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
+
+from MaxText import checkpointing
+from MaxText import max_logging
+from MaxText import max_utils
+from MaxText import maxtext_utils
+
+
+def setup_initial_state(
+    model,
+    data_iterator,
+    tx,
+    config,
+    rng,
+    mesh,
+    checkpoint_manager,
+    is_training=True,
+):
+  """Initializes the model and optimizer state, optionally restoring from a checkpoint.
+
+  Args:
+    model: the flax model to initialize
+    data_iterator: the data iterator whose state may be restored along with the model
+    tx: the optax.GradientTransformation
+    config: config object
+    rng: jax.prng key
+    mesh: jax.devices() mesh
+    checkpoint_manager: an Orbax checkpointing.CheckpointManager object
+    is_training: True to initialize training state, False for decode state
+
+  Returns:
+    state: the initialized train state
+    state_mesh_annotations: the mesh annotations for the train state
+    state_mesh_shardings: the shardings for the train state
+    data_iterator: the data iterator, with its state restored if the checkpoint provided one
+  """
+  unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
+      model, tx, config, rng, mesh, is_training
+  )
+
+  # Initialization
+  with nn_partitioning.axis_rules(config.logical_axis_rules):
+    restored, raw_params = checkpointing.load_state_if_possible(
+        checkpoint_manager,
+        data_iterator,
+        config.load_parameters_path,
+        config.load_full_state_path,
+        config.checkpoint_storage_concurrent_gb,
+        unboxed_abstract_state,
+        config.enable_single_replica_ckpt_restoring,
+        config.dataset_type,
+        use_ocdbt=config.checkpoint_storage_use_ocdbt,
+        use_zarr3=config.checkpoint_storage_use_zarr3,
+    )
+
+    if restored:
+      if isinstance(
+          checkpoint_manager,
+          (
+              emergency_checkpoint_manager.CheckpointManager,
+              emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager,
+          ),
+      ):
+        state = restored
+      else:
+        if "iter" in restored and restored["iter"] is not None:
+          data_iterator.local_iterator = restored["iter"]
+        state = restored["items"]
+    else:
+      init_state_partial = functools.partial(maxtext_utils.init_initial_state, model, tx, config, is_training)
+      init_state_partial.__name__ = "initialize_state"
+      # pylint: disable=not-callable
+      state = jax.jit(
+          init_state_partial,
+          in_shardings=None,
+          out_shardings=state_mesh_shardings,
+      )(rng)
+      if raw_params:  # If we loaded a partial state, we need to merge it.
+        state = state.replace(params=raw_params)
+
+  state = max_utils.unbox_logicallypartioned(state)
+
+  return state, state_mesh_annotations, state_mesh_shardings, data_iterator
+
+
+def setup_decode_state(model, config, rng, mesh, checkpoint_manager):
+  """Sets up decode state by loading params from a checkpoint.
+
+  Args:
+    model: the flax model to initialize
+    config: config object
+    rng: jax.prng key
+    mesh: jax.devices() mesh
+    checkpoint_manager: Checkpoint manager
+
+  Returns:
+    state: state with decode params loaded from the checkpoint
+    state_mesh_annotations: the mesh annotations for the state
+  """
+  if not config.load_parameters_path:
+    # Generate random params.
+    max_logging.log("No decode checkpoint specified - generating random weights.")
+    state, state_mesh_annotations, _, _ = setup_initial_state(
+        model, None, None, config, rng, mesh, checkpoint_manager, False
+    )
+  else:
+    # Load params from checkpoint.
+    max_logging.log(f"Loading decode params from {config.load_parameters_path}")
+    unboxed_abstract_state, state_mesh_annotations, _ = maxtext_utils.get_abstract_state(
+        model, None, config, rng, mesh, False
+    )
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      params = checkpointing.load_params_from_path(
+          config.load_parameters_path,
+          unboxed_abstract_state.params,
+          config.checkpoint_storage_concurrent_gb,
+          config.checkpoint_storage_use_ocdbt,
+          config.checkpoint_storage_use_zarr3,
+      )
+    state = maxtext_utils.init_decode_state(None, params)
+
+  state = max_utils.unbox_logicallypartioned(state)
+  return state, state_mesh_annotations
+
+
+def save_quantized_checkpoint_if_configured(config, params):
+  """Saves quantized params when a save path is configured; otherwise logs and skips."""
+  assert config.quantization, "quantization must be configured"
+  if config.save_quantized_params_path:
+    checkpointing.save_params_to_path(config.save_quantized_params_path, params)
+  else:
+    max_logging.log("Skipping saving quantized checkpoint as save_quantized_params_path is null.")
+
+
+def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager):
+  """Wrapper around setup_initial_state with is_training=True."""
+  is_training = True
+  return setup_initial_state(
+      model,
+      data_iterator,
+      tx,
+      config,
+      rng,
+      mesh,
+      checkpoint_manager,
+      is_training,
+  )
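Taken together, the call-site changes across the 19 files are mechanical: the function signatures stay the same and only the module prefix changes. A before/after sketch of a typical caller, assuming model, data_iterator, tx, config, init_rng, mesh, and checkpoint_manager are already constructed as in the diffs above:

# Before this commit: state setup lived in maxtext_utils.
from MaxText import maxtext_utils

state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
    model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
)

# After this commit: identical signature, imported from the new module instead.
from MaxText import maxtext_state_initialization_utils

state, _, state_mesh_shardings, data_iterator = maxtext_state_initialization_utils.setup_training_state(
    model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
)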
