-
Notifications
You must be signed in to change notification settings - Fork 368
Moving the checkpointing util funtions to a spearate file to avoid circular dependency #1623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
wang2yn84
wants to merge
1
commit into
main
Choose a base branch
from
lance-refactor
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
""" | ||
Copyright 2025 Google LLC | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
https://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
# pylint: disable=bare-except, consider-using-generator | ||
"""Utils that are only interesting to MaxText. To break the circular dependency, all the util functions relying on MaxText.checkpointing is here.""" | ||
|
||
import jax | ||
|
||
from MaxText import max_utils | ||
|
||
import functools | ||
|
||
|
||
from flax.linen import partitioning as nn_partitioning | ||
|
||
from MaxText import max_logging | ||
from MaxText import maxtext_utils | ||
from MaxText import checkpointing | ||
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager | ||
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager | ||
|
||
|
||
def setup_initial_state( | ||
model, | ||
data_iterator, | ||
tx, | ||
config, | ||
rng, | ||
mesh, | ||
checkpoint_manager, | ||
is_training=True, | ||
): | ||
"""We initialize the model and optimizer state, and optionally load from a | ||
checkpoint as necessary. | ||
|
||
Args: | ||
model: the flax model to initialize | ||
tx: the optax.GradientTransformation | ||
config: config object | ||
rng: jax.prng key | ||
mesh: jax.devices() mesh | ||
checkpoint_manager: an Orbax checkpointing.CheckpointManager object | ||
is_training: True to initialize training state, False for decode state | ||
|
||
Returns: | ||
state: the initialized train state | ||
state_mesh_annotations: the mesh annotations for the train state | ||
""" | ||
|
||
unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( | ||
model, tx, config, rng, mesh, is_training | ||
) | ||
|
||
# Initialization | ||
with nn_partitioning.axis_rules(config.logical_axis_rules): | ||
restored, raw_params = checkpointing.load_state_if_possible( | ||
checkpoint_manager, | ||
data_iterator, | ||
config.load_parameters_path, | ||
config.load_full_state_path, | ||
config.checkpoint_storage_concurrent_gb, | ||
unboxed_abstract_state, | ||
config.enable_single_replica_ckpt_restoring, | ||
config.dataset_type, | ||
use_ocdbt=config.checkpoint_storage_use_ocdbt, | ||
use_zarr3=config.checkpoint_storage_use_zarr3, | ||
) | ||
|
||
if restored: | ||
if isinstance( | ||
checkpoint_manager, | ||
( | ||
emergency_checkpoint_manager.CheckpointManager, | ||
emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager, | ||
), | ||
): | ||
state = restored | ||
else: | ||
if "iter" in restored and restored["iter"] is not None: | ||
data_iterator.local_iterator = restored["iter"] | ||
state = restored["items"] | ||
else: | ||
init_state_partial = functools.partial(maxtext_utils.init_initial_state, model, tx, config, is_training) | ||
init_state_partial.__name__ = "initialize_state" | ||
# pylint: disable=not-callable | ||
state = jax.jit( | ||
init_state_partial, | ||
in_shardings=None, | ||
out_shardings=state_mesh_shardings, | ||
)(rng) | ||
if raw_params: # If we loaded a partial state, we need to merge it. | ||
state = state.replace(params=raw_params) | ||
|
||
state = max_utils.unbox_logicallypartioned(state) | ||
|
||
return state, state_mesh_annotations, state_mesh_shardings, data_iterator | ||
|
||
|
||
def setup_decode_state(model, config, rng, mesh, checkpoint_manager): | ||
"""Setup decode state by loading params from a checkpoint. | ||
Args: | ||
model: the flax model to initialize | ||
config: config object | ||
rng: jax.prng key | ||
mesh: jax.devices() mesh | ||
checkpoint_manager: Checkpoint manager | ||
|
||
Returns: | ||
state: state with decode params loaded from the checkpoint | ||
state_mesh_annotations: the mesh annotations for the state | ||
""" | ||
if not config.load_parameters_path: | ||
# generate random params | ||
max_logging.log("No decode checkpoint specified - generating random weights.") | ||
state, state_mesh_annotations, _, _ = setup_initial_state( | ||
model, None, None, config, rng, mesh, checkpoint_manager, False | ||
) | ||
else: | ||
# Load params from checkpoint | ||
max_logging.log(f"Loading decode params from {config.load_parameters_path}") | ||
unboxed_abstract_state, state_mesh_annotations, _ = maxtext_utils.get_abstract_state(model, None, config, rng, mesh, False) | ||
with nn_partitioning.axis_rules(config.logical_axis_rules): | ||
params = checkpointing.load_params_from_path( | ||
config.load_parameters_path, unboxed_abstract_state.params, config.checkpoint_storage_concurrent_gb, config.checkpoint_storage_use_ocdbt, config.checkpoint_storage_use_zarr3 | ||
) | ||
state = maxtext_utils.init_decode_state(None, params) | ||
|
||
state = max_utils.unbox_logicallypartioned(state) | ||
return state, state_mesh_annotations | ||
|
||
|
||
|
||
def save_quantized_checkpoint_if_configured(config, params): | ||
assert config.quantization, "quantization must be configured" | ||
if config.save_quantized_params_path: | ||
checkpointing.save_params_to_path(config.save_quantized_params_path, params) | ||
else: | ||
"Skipping saving quantized checkpoint as save_quantized_params_path is null." | ||
|
||
|
||
def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): | ||
is_training = True | ||
return setup_initial_state( | ||
model, | ||
data_iterator, | ||
tx, | ||
config, | ||
rng, | ||
mesh, | ||
checkpoint_manager, | ||
is_training, | ||
) | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you have some inference folks look at this change @vipannalla or @mitalisi