Moving the checkpointing util functions to a separate file to avoid circular dependency #1623

Open · wants to merge 1 commit into base: main
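Background, illustrative only: the sketch below reproduces the kind of circular import this refactor avoids. The modules util.py and ckpt.py are made-up stand-ins for MaxText.maxtext_utils and MaxText.checkpointing; nothing here is MaxText code.

# Illustrative only: a tiny, self-contained reproduction of a circular-import
# failure. util.py and ckpt.py are made-up stand-ins, not MaxText modules.
import pathlib
import sys
import tempfile
import textwrap

tmp = pathlib.Path(tempfile.mkdtemp())
(tmp / "util.py").write_text(
    textwrap.dedent(
        """
        from ckpt import load_state  # the util layer needs checkpointing...

        def setup_state():
            return load_state()
        """
    )
)
(tmp / "ckpt.py").write_text(
    textwrap.dedent(
        """
        from util import setup_state  # ...and checkpointing needs the util layer

        def load_state():
            return "state"
        """
    )
)
sys.path.insert(0, str(tmp))

try:
    import util  # noqa: F401
except ImportError as err:
    # Typically: "cannot import name 'setup_state' from partially initialized
    # module 'util' (most likely due to a circular import)".
    print(err)

# The fix in this PR follows the standard pattern: the helpers that need
# checkpointing move into a third module (maxtext_state_initialization_utils)
# that imports both sides, so neither original module has to import the other.
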
3 changes: 2 additions & 1 deletion MaxText/convert_gpt3_ckpt_from_paxml.py
@@ -52,6 +52,7 @@
 import jax
 import gc
 from MaxText import max_logging
+from MaxText import maxtext_state_initialization_utils
 from psutil import Process
 from MaxText.train import save_checkpoint
 import argparse
@@ -109,7 +110,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
 cfg.checkpoint_period,
 )

-state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
+state, _, _, _ = maxtext_state_initialization_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
 max_logging.log("start")
 check_memory()

3 changes: 2 additions & 1 deletion MaxText/elastic_train.py
@@ -61,6 +61,7 @@
 from MaxText import checkpointing
 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText import maxtext_state_initialization_utils
 from MaxText import max_logging
 from MaxText import profiler
 from MaxText import pyconfig
@@ -133,7 +134,7 @@ def elastic_handler(
 max_logging.log(f"Deleting checkpoint from step {latest_step} since we are rewinding to step {step}.")
 checkpoint_manager.delete(latest_step)

-state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
+state, _, state_mesh_shardings, data_iterator = maxtext_state_initialization_utils.setup_training_state(
 model,
 data_iterator,
 tx,
3 changes: 2 additions & 1 deletion MaxText/experimental/rl/grpo_trainer.py
@@ -42,6 +42,7 @@
 from MaxText import checkpointing
 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText import maxtext_state_initialization_utils
 from MaxText import max_logging
 from MaxText import profiler
 from MaxText import pyconfig
@@ -649,7 +650,7 @@ def setup_train_loop(config):
 record_goodput(recorder, config, recorder.record_tpu_init_end_time if recorder else None)
 record_goodput(recorder, config, recorder.record_training_preparation_start_time if recorder else None)
 data_iterator, eval_data_iterator = grpo_input_pipeline.create_data_iterator(config, mesh)
-state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
+state, _, state_mesh_shardings, data_iterator = maxtext_state_initialization_utils.setup_training_state(
 model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
 )

3 changes: 2 additions & 1 deletion MaxText/generate_param_only_checkpoint.py
@@ -28,6 +28,7 @@
 import jax
 from MaxText import max_logging
 from MaxText import max_utils
+from MaxText import maxtext_state_initialization_utils
 from MaxText import maxtext_utils
 from MaxText import optimizers
 from MaxText import pyconfig
@@ -98,7 +99,7 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
 rng = random.PRNGKey(0)
 learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
 tx = optimizers.get_optimizer(config, learning_rate_schedule)
-state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state(
+state, state_mesh_notations, _, _ = maxtext_state_initialization_utils.setup_training_state(
 model, None, tx, config, rng, mesh, checkpoint_manager
 )
 num_params = max_utils.calculate_num_params_from_pytree(state.params)
5 changes: 3 additions & 2 deletions MaxText/maxengine.py
@@ -38,6 +38,7 @@

 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText import maxtext_state_initialization_utils
 from MaxText import inference_utils
 from MaxText import pyconfig
 from MaxText import common_types
@@ -233,7 +234,7 @@ def load_params(self, *args, params=None, rng: Optional[PRNGKeyType] = None, **k
 state = maxtext_utils.init_decode_state(None, params)
 state = max_utils.unbox_logicallypartioned(state)
 else:
-state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.model, self.config, rng1, self._mesh, None)
+state, self.state_mesh_annotations = maxtext_state_initialization_utils.setup_decode_state(self.model, self.config, rng1, self._mesh, None)
 # pylint: disable=isinstance-second-argument-not-valid-type
 self.abstract_params = jax.tree_util.tree_map(
 lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
@@ -340,7 +341,7 @@ def model_apply(_p, _rng):
 lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding),
 params,
 )
-maxtext_utils.save_quantized_checkpoint_if_configured(self.config, params)
+maxtext_state_initialization_utils.save_quantized_checkpoint_if_configured(self.config, params)
 self.model.quant.quant_mode = quantizations.get_quant_mode("serve")
 return params

166 changes: 166 additions & 0 deletions MaxText/maxtext_state_initialization_utils.py
@@ -0,0 +1,166 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# pylint: disable=bare-except, consider-using-generator
"""Utils that are only interesting to MaxText. To break the circular dependency, all the util functions relying on MaxText.checkpointing is here."""

import functools

import jax

from flax.linen import partitioning as nn_partitioning

import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager

from MaxText import checkpointing
from MaxText import max_logging
from MaxText import max_utils
from MaxText import maxtext_utils


def setup_initial_state(
    model,
    data_iterator,
    tx,
    config,
    rng,
    mesh,
    checkpoint_manager,
    is_training=True,
):
  """We initialize the model and optimizer state, and optionally load from a
  checkpoint as necessary.

  Args:
    model: the flax model to initialize
    tx: the optax.GradientTransformation
    config: config object
    rng: jax.prng key
    mesh: jax.devices() mesh
    checkpoint_manager: an Orbax checkpointing.CheckpointManager object
    is_training: True to initialize training state, False for decode state

  Returns:
    state: the initialized train state
    state_mesh_annotations: the mesh annotations for the train state
  """

  unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
      model, tx, config, rng, mesh, is_training
  )

  # Initialization
  with nn_partitioning.axis_rules(config.logical_axis_rules):
    restored, raw_params = checkpointing.load_state_if_possible(
        checkpoint_manager,
        data_iterator,
        config.load_parameters_path,
        config.load_full_state_path,
        config.checkpoint_storage_concurrent_gb,
        unboxed_abstract_state,
        config.enable_single_replica_ckpt_restoring,
        config.dataset_type,
        use_ocdbt=config.checkpoint_storage_use_ocdbt,
        use_zarr3=config.checkpoint_storage_use_zarr3,
    )

  if restored:
    if isinstance(
        checkpoint_manager,
        (
            emergency_checkpoint_manager.CheckpointManager,
            emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager,
        ),
    ):
      state = restored
    else:
      if "iter" in restored and restored["iter"] is not None:
        data_iterator.local_iterator = restored["iter"]
      state = restored["items"]
  else:
    init_state_partial = functools.partial(maxtext_utils.init_initial_state, model, tx, config, is_training)
    init_state_partial.__name__ = "initialize_state"
    # pylint: disable=not-callable
    state = jax.jit(
        init_state_partial,
        in_shardings=None,
        out_shardings=state_mesh_shardings,
    )(rng)
    if raw_params:  # If we loaded a partial state, we need to merge it.
      state = state.replace(params=raw_params)

  state = max_utils.unbox_logicallypartioned(state)

  return state, state_mesh_annotations, state_mesh_shardings, data_iterator


def setup_decode_state(model, config, rng, mesh, checkpoint_manager):
  """Setup decode state by loading params from a checkpoint.

  Args:
    model: the flax model to initialize
    config: config object
    rng: jax.prng key
    mesh: jax.devices() mesh
    checkpoint_manager: Checkpoint manager

  Returns:
    state: state with decode params loaded from the checkpoint
    state_mesh_annotations: the mesh annotations for the state
  """
  if not config.load_parameters_path:
    # generate random params
    max_logging.log("No decode checkpoint specified - generating random weights.")
    state, state_mesh_annotations, _, _ = setup_initial_state(
        model, None, None, config, rng, mesh, checkpoint_manager, False
    )
  else:
    # Load params from checkpoint
    max_logging.log(f"Loading decode params from {config.load_parameters_path}")
    unboxed_abstract_state, state_mesh_annotations, _ = maxtext_utils.get_abstract_state(
        model, None, config, rng, mesh, False
    )
    with nn_partitioning.axis_rules(config.logical_axis_rules):
      params = checkpointing.load_params_from_path(
          config.load_parameters_path,
          unboxed_abstract_state.params,
          config.checkpoint_storage_concurrent_gb,
          config.checkpoint_storage_use_ocdbt,
          config.checkpoint_storage_use_zarr3,
      )
    state = maxtext_utils.init_decode_state(None, params)

  state = max_utils.unbox_logicallypartioned(state)
  return state, state_mesh_annotations



Collaborator review comment: Can you have some inference folks look at this change @vipannalla or @mitalisi

def save_quantized_checkpoint_if_configured(config, params):
  assert config.quantization, "quantization must be configured"
  if config.save_quantized_params_path:
    checkpointing.save_params_to_path(config.save_quantized_params_path, params)
  else:
    max_logging.log("Skipping saving quantized checkpoint as save_quantized_params_path is null.")


def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager):
  is_training = True
  return setup_initial_state(
      model,
      data_iterator,
      tx,
      config,
      rng,
      mesh,
      checkpoint_manager,
      is_training,
  )
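
Illustrative only: a minimal sketch of the call-site pattern after this change, mirroring the edits in the files above. The wrapper function build_training_state is hypothetical, and model, data_iterator, tx, config, init_rng, mesh and checkpoint_manager are assumed to be constructed as usual by the caller.

# Illustrative only: call-site pattern after this PR. The wrapper function is
# hypothetical; its arguments are assumed to be built the usual way elsewhere.
from MaxText import maxtext_state_initialization_utils


def build_training_state(model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager):
  # Previously this call went through maxtext_utils.setup_training_state(...).
  return maxtext_state_initialization_utils.setup_training_state(
      model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
  )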
