-
Notifications
You must be signed in to change notification settings - Fork 6.2k
[core] respect local_files_only=True
when using sharded checkpoints
#12005
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This reverts commit 8d431dc.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@@ -403,9 +404,26 @@ def _get_checkpoint_shard_files( | |||
|
|||
ignore_patterns = ["*.json", "*.md"] | |||
# `model_info` call must guarded with the above condition. | |||
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the purpose of this check is to verify if the necessary sharded files are present in the model repo before attempting a download, presumably to avoid a large download if all files aren't present. If we cannot connect to the hub, we just have to assume the necessary shard files are already present locally.
I think we can just skip this check if local_files_only=True
and then check if all the shard filenames are present in the cached_folder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think just this is sufficient
if not local_files_only:
# run model_info check
Run snapshot download
Then after the cached_filenames is created, iterate over the files to verify they exist
for filename in cached_filename:
if not if not os.path.exists(filename):
raise EnvironmentError("expected file not present in {cached_folder}")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- We don't have to run
snapshot_download()
whenlocal_files_only=False
, that might be unnecessary. - Why run
snapshot_download()
after also runningmodel_info()
? - Even if we run
snapshot_download()
regardless oflocal_files_only
var, I think we should have it insidetry-except
in case the endpoint cannot be pinged for some reason and raise theConnectionError
as before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See if b7af511 resolves this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see what you mean. Let me update. Sorry about the back and forth.
@DN6 see if the latest changes work for you. |
What does this PR do?
See: #11948