Skip to content

Moving the checkpointing util funtions to a spearate file to avoid circular dependency #1623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

wang2yn84
Copy link
Collaborator

@wang2yn84 wang2yn84 commented Apr 24, 2025

Description

The circular dependency is multihost_dataloading -> maxtext_utils -> checkpointing -> multihost_dataloading. That's why in the previous development, we removed the deps of maxtext_utils from multihost_dataloading to avoid the issue, but that's just temporarily. Moving all the maxtext_util functions that depends on checkpointing to a separate util function solves this issue.

Tests

Presubmit tests.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Copy link
Collaborator

@bvandermoon bvandermoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @wang2yn84 for the nice refactor. Was the circular dependency causing some type of error? Or is the change for cleaner code?

@@ -52,6 +52,7 @@
import jax
import gc
from MaxText import max_logging
from MaxText import maxtext_checkpointing_utils
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like from MaxText import max_utils is no longer needed. Can you remove it from this file and other ones where it isn't needed either?

)
state = maxtext_utils.init_decode_state(None, params)

state = max_utils.unbox_logicallypartioned(state)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change trying to get rid of the max_utils dependency?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change is to move all the util functions that depends on MaxText.checkpointing to this file so that other files can depends on maxtext_utils without circular dependency issue. maxtext_checkpointing_utils can depend on max_utils and maxtext_utils without any issue.

@wang2yn84
Copy link
Collaborator Author

Thanks @wang2yn84 for the nice refactor. Was the circular dependency causing some type of error? Or is the change for cleaner code?

Hi Branden, in the current implementation, we have to carefully avoid depend on maxtext_utils from multihost_dataloading. There are couple of util functions in multihost_dataloading we should place in maxtext_utils but we can't.

"Skipping saving quantized checkpoint as save_quantized_params_path is null."


def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this method fits in this file - this define the training state, its needed even without checkpointing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chatted offline, agreed to rename this new util file to maxtext_state_initializion_utils.py.

return state, state_mesh_annotations, state_mesh_shardings, data_iterator


def setup_decode_state(model, config, rng, mesh, checkpoint_manager):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a checkpointing method, its needed even when checkpointing is turned off

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager


def setup_initial_state(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not a checkpointing function, its needed even when checkpointing is turned off

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

@wang2yn84 wang2yn84 force-pushed the lance-refactor branch 3 times, most recently from fa152d3 to f3a0a49 Compare April 25, 2025 23:33



def save_quantized_checkpoint_if_configured(config, params):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you have some inference folks look at this change @vipannalla or @mitalisi

Copy link
Collaborator

@vipannalla vipannalla left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a simple split, not new functionality. LGTM, did you do a end-2-end run for inference to confirm nothing is broken?

@wang2yn84 wang2yn84 force-pushed the lance-refactor branch 2 times, most recently from ba55551 to 83950d0 Compare April 30, 2025 18:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants