-
Notifications
You must be signed in to change notification settings - Fork 367
Moving the checkpointing util funtions to a spearate file to avoid circular dependency #1623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @wang2yn84 for the nice refactor. Was the circular dependency causing some type of error? Or is the change for cleaner code?
@@ -52,6 +52,7 @@ | |||
import jax | |||
import gc | |||
from MaxText import max_logging | |||
from MaxText import maxtext_checkpointing_utils |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like from MaxText import max_utils
is no longer needed. Can you remove it from this file and other ones where it isn't needed either?
) | ||
state = maxtext_utils.init_decode_state(None, params) | ||
|
||
state = max_utils.unbox_logicallypartioned(state) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change trying to get rid of the max_utils dependency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The change is to move all the util functions that depends on MaxText.checkpointing to this file so that other files can depends on maxtext_utils without circular dependency issue. maxtext_checkpointing_utils can depend on max_utils and maxtext_utils without any issue.
Hi Branden, in the current implementation, we have to carefully avoid depend on maxtext_utils from multihost_dataloading. There are couple of util functions in multihost_dataloading we should place in maxtext_utils but we can't. |
"Skipping saving quantized checkpoint as save_quantized_params_path is null." | ||
|
||
|
||
def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this method fits in this file - this define the training state, its needed even without checkpointing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chatted offline, agreed to rename this new util file to maxtext_state_initializion_utils.py.
return state, state_mesh_annotations, state_mesh_shardings, data_iterator | ||
|
||
|
||
def setup_decode_state(model, config, rng, mesh, checkpoint_manager): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not a checkpointing method, its needed even when checkpointing is turned off
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager | ||
|
||
|
||
def setup_initial_state( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is not a checkpointing function, its needed even when checkpointing is turned off
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
fa152d3
to
f3a0a49
Compare
|
||
|
||
|
||
def save_quantized_checkpoint_if_configured(config, params): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you have some inference folks look at this change @vipannalla or @mitalisi
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like a simple split, not new functionality. LGTM, did you do a end-2-end run for inference to confirm nothing is broken?
ba55551
to
83950d0
Compare
…rcular dependency.
83950d0
to
d34944c
Compare
Description
The circular dependency is multihost_dataloading -> maxtext_utils -> checkpointing -> multihost_dataloading. That's why in the previous development, we removed the deps of maxtext_utils from multihost_dataloading to avoid the issue, but that's just temporarily. Moving all the maxtext_util functions that depends on checkpointing to a separate util function solves this issue.
Tests
Presubmit tests.
Checklist
Before submitting this PR, please make sure (put X in square brackets):