Multi-tier checkpointing + orbax replicator #1332
replicator_yaml = f"""job-name: {run_name}
framework: orbax
assume-data-parallelism: 2
@DmitryKakurin how do we set this correctly in the following scenarios?
- ici_dp=8, dcn_dp=4, should it be 32?
- ici_dp=1, dcn_dp=4
It seems MaxText hardcodes this as num_slices, which isn't always the case in axlearn: https://github.com/AI-Hypercomputer/maxtext/blob/946f2a60c2c05f8fe8540fefb7e44b3838f87eb3/MaxText/max_utils.py#L306
This should come as a config param in the Config class.
It's still unclear what value to use when ici_dp is greater than 1.
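To make the two scenarios in this thread concrete, here is a small sketch that computes the two candidate values being debated: the data-parallel degree across slices only (`dcn_dp`) and the total data-parallel degree (`ici_dp * dcn_dp`). The function name and the dictionary keys are hypothetical labels, not replicator options, and the thread does not settle which of the two the replicator expects.

```python
def candidate_data_parallelism(ici_dp: int, dcn_dp: int) -> dict[str, int]:
    """Return the two candidate values discussed in this thread.

    The keys are illustrative labels only (an assumption of this sketch).
    """
    if ici_dp < 1 or dcn_dp < 1:
        raise ValueError("ici_dp and dcn_dp must be positive integers.")
    return {
        # Data parallelism across slices (over DCN) only.
        "across_slices_only": dcn_dp,
        # Data parallelism within a slice (ICI) times across slices (DCN).
        "total": ici_dp * dcn_dp,
    }


# The two scenarios raised in the review comment.
for ici_dp, dcn_dp in [(8, 4), (1, 4)]:
    print(f"ici_dp={ici_dp}, dcn_dp={dcn_dp}:", candidate_data_parallelism(ici_dp, dcn_dp))
```

The two candidates agree only when ici_dp is 1, which is why the second scenario is unambiguous and the first one is not.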
@@ -50,7 +44,7 @@ def setup(spec: str):
         parsed_args[k] = v
     if "local_ckpt_dir" not in parsed_args:
         raise ValueError("local_ckpt_dir must be specified.")
-    # Get process ID and IP of coordinator
+    # Get process ID and IP of jax coordinator
Please change "Get process ID" -> "Get our process ID"
What does "our" mean? Can we be more specific?
@@ -133,6 +135,8 @@ class Config(BaseCheckpointer.Config):
         async_timeout_secs: Timeout for async barrier in seconds when saving tensors.
     """
+
+    assume_data_parallelism: int = 2
There should be no default for this one - it's not guessable. And if it's not set correctly things will not go well.
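A minimal sketch of what "no default" could look like, using a plain dataclass rather than axlearn's actual config machinery: the field must be passed explicitly, and the value is validated before it is written into the replicator YAML. `ReplicatorSettings` and `render_replicator_yaml` are hypothetical names; only the three YAML keys are taken from the diff in this PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ReplicatorSettings:
    """Hypothetical stand-in for the checkpointer Config fields used here."""

    run_name: str
    # Deliberately no default: the caller has to choose this value.
    assume_data_parallelism: int

    def __post_init__(self):
        value = self.assume_data_parallelism
        if not isinstance(value, int) or isinstance(value, bool) or value < 1:
            raise ValueError(
                f"assume_data_parallelism must be a positive integer, got {value!r}."
            )


def render_replicator_yaml(settings: ReplicatorSettings) -> str:
    """Build the replicator YAML from settings instead of a hardcoded value."""
    return (
        f"job-name: {settings.run_name}\n"
        "framework: orbax\n"
        f"assume-data-parallelism: {settings.assume_data_parallelism}\n"
    )


print(render_replicator_yaml(ReplicatorSettings(run_name="demo", assume_data_parallelism=4)))
```

Omitting `assume_data_parallelism` raises a `TypeError` at construction time, so a misconfigured job fails before any checkpointing starts rather than silently running with a guessed value.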
Opening as a very rough draft PR to collect initial feedback.
Checkpoint restores from both local checkpoints and GCS backups are working, but some sections for non-tensor management are commented out and will need to be integrated back in.