Multi-tier checkpointing + orbax replicator #1332


Draft: ehorning wants to merge 25 commits into main

Conversation

ehorning (Contributor)

Opening as a very rough draft PR to collect initial feedback.

Checkpoint restores for both local checkpoints and GCS backups are currently working, but some sections for non-tensor management are currently commented out and will need to be integrated back in.
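For readers new to the multi-tier idea, here is a rough sketch of the restore decision described above: prefer the fast local tier when it holds a checkpoint at least as recent as the GCS backup, and otherwise fall back to GCS. This only illustrates the concept; the function name and tier labels are hypothetical and not code from this PR.

from typing import Optional

def choose_restore_tier(local_step: Optional[int], gcs_step: Optional[int]) -> Optional[str]:
    """Picks which tier to restore from, given the newest complete step in each tier."""
    if local_step is None and gcs_step is None:
        return None  # Nothing to restore; start training from scratch.
    if gcs_step is None or (local_step is not None and local_step >= gcs_step):
        return "local"  # Fast path: the local checkpoint is at least as fresh.
    return "gcs"  # Fall back to the slower GCS backup.

# Local tier lost (e.g. the node was replaced) while the GCS backup has step 2000.
assert choose_restore_tier(None, 2000) == "gcs"
# Local tier is ahead of the last GCS backup.
assert choose_restore_tier(2500, 2000) == "local"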


replicator_yaml = f"""job-name: {run_name}
framework: orbax
assume-data-parallelism: 2
Contributor

@DmitryKakurin how do we set this correctly in the following scenarios?

  1. ici_dp=8, dcn_dp=4, should it be 32?
  2. ici_dp=1, dcn_dp=4?

It seems that in MaxText this is hardcoded as num_slices, which isn't always the case in axlearn: https://github.com/AI-Hypercomputer/maxtext/blob/946f2a60c2c05f8fe8540fefb7e44b3838f87eb3/MaxText/max_utils.py#L306


This should come as a config param in Config class.

Contributor

Still unclear what to use when ici_dp is bigger than 1
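To make the scenarios above concrete, here is a tiny illustrative helper showing the two candidate values being debated. The helper name is made up, it assumes DCN parallelism is pure data parallelism (so dcn_dp equals the number of slices), and the thread leaves open which value the replicator actually expects.

def candidate_assume_data_parallelism(ici_dp: int, dcn_dp: int) -> dict:
    """Two interpretations of assume-data-parallelism for a given mesh."""
    return {
        # What the linked MaxText code effectively uses: the number of slices.
        "num_slices_style": dcn_dp,
        # Total number of data-parallel replicas across ICI and DCN.
        "total_dp_style": ici_dp * dcn_dp,
    }

# Scenario 1: ici_dp=8, dcn_dp=4 -> 4 vs. 32 (the "should it be 32?" question).
print(candidate_assume_data_parallelism(8, 4))
# Scenario 2: ici_dp=1, dcn_dp=4 -> both interpretations agree on 4.
print(candidate_assume_data_parallelism(1, 4))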

@@ -50,7 +44,7 @@ def setup(spec: str):
parsed_args[k] = v
if "local_ckpt_dir" not in parsed_args:
raise ValueError("local_ckpt_dir must be specified.")
-# Get process ID and IP of coordinator
+# Get process ID and IP of jax coordinator


Please change "Get process ID" -> "Get our process ID"

Contributor Author

What does "our" mean? Can we be more specific?
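For context on the comment under discussion: in a multi-process JAX job, "our process ID" would be the index of the current process, and the coordinator is the process every other process connects to at startup. A minimal sketch follows; the environment-variable names are assumptions, not what this PR uses, and only jax.distributed.initialize and jax.process_index are real APIs.

import os
import jax

# Hypothetical wiring for the current ("our") process and the jax coordinator.
coordinator_address = os.environ.get("COORDINATOR_ADDRESS", "localhost:1234")
process_id = int(os.environ.get("PROCESS_ID", "0"))
num_processes = int(os.environ.get("NUM_PROCESSES", "1"))

jax.distributed.initialize(
    coordinator_address=coordinator_address,
    num_processes=num_processes,
    process_id=process_id,
)

# After initialization, the ID of the current process is available as:
assert jax.process_index() == process_id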



@@ -133,6 +135,8 @@ class Config(BaseCheckpointer.Config):
async_timeout_secs: Timeout for async barrier in seconds when saving tensors.
"""

+assume_data_parallelism: int = 2


There should be no default for this one - it's not guessable. And if it's not set correctly things will not go well.
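If the default is removed, the field could be declared as required instead. A sketch assuming axlearn's usual config helpers; the checkpointer class name is hypothetical and the import paths are my assumption, not taken from this PR.

from axlearn.common.checkpointer import BaseCheckpointer
from axlearn.common.config import REQUIRED, Required, config_class

class MultiTierCheckpointer(BaseCheckpointer):  # hypothetical name
    @config_class
    class Config(BaseCheckpointer.Config):
        # No default: the value is not guessable, and a wrong guess breaks replication.
        assume_data_parallelism: Required[int] = REQUIRED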
