Fix for fetching variants only #10646
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks @DN6 for supporting!
for filename in non_variant_filenames:
    if convert_to_variant(filename) in variant_filenames:
        continue
return any(f.startswith(component) for f in variant_filenames)
what happens if we only have a bf16.bin and there is a non-variant safetensors?
As in we are trying to fetch something like this?
variant = "fp16"
filenames = [
    f"vae/diffusion_pytorch_model.{variant}.bin",
    f"text_encoder/model.{variant}.bin",
    f"unet/diffusion_pytorch_model.{variant}.bin",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
like this, I think we should fetch the non-variant safetensors in this case, no?
variant = "fp16"
filenames = [
    f"vae/diffusion_pytorch_model.{variant}.bin",
    f"text_encoder/model.{variant}.bin",
    f"unet/diffusion_pytorch_model.{variant}.bin",
    f"vae/diffusion_pytorch_model.safetensors",
    f"text_encoder/model.safetensors",
    f"unet/diffusion_pytorch_model.safetensors",
]
Hmm, currently the behaviour on main is to return all the files in that list (both bin and safetensors) as usable_filenames, and I think the ignore patterns would then remove the bin files, resulting in just the safetensors being downloaded.
With this change only the fp16.bin files would be downloaded, which feels technically "correct" to me since they are the "variant" files of each component. IMO non-variants should only be downloaded if no variant exists (regardless of format).
But this case implies that the proposal here is a breaking change, so I'll update to account for it.
@yiyixuxu Had to do a bit more of a refactor to account for safetensors prioritization. But it should be much more robust to handle any number of repo file combinations. I've added a test for your case and a few others as well. I think they should cover all likely repo file layout scenarios, but if there are others I may have missed lmk.
Here is another test case with fp16, failing after using 2 updated files from
Repo
it seems we are changing the behavior of
from_flax , use_onnx and use_safetensors )
so my question is:
Yes, I think it would be better to add an extension filter/check earlier and then filter variants after (without extensions checks). The current behaviour of decoupling the steps is a bit confusing and makes it harder to reason about which files end up getting downloaded. For instance in this case:
I think the core issue is that we apply checks/filters currently over the full list of filenames while we should actually be checking per component in order to support mixed downloading properly, as well as things like downloading safetensors over bin files, non-sharded-variants for sharded-non-variants etc. We can either add additional checks in
It was an oversight on my part. I was focusing on solving the use case you pointed out and forgot about the use_onnx and from_flax flags.
both approaches sound good to me :)
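To make the ordering discussed in the comments above concrete, here is a minimal sketch of "filter by extension first, then pick variants per component". It is not the code from this PR: `select_files`, `allowed_exts`, and the helper names are hypothetical, and the sketch deliberately ignores sharded checkpoints and index files by treating `name.ext` as a non-variant and `name.<variant>.ext` as a variant.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Set


def _parts(filename: str) -> List[str]:
    """Split the basename (folder stripped) on dots."""
    return filename.rsplit("/", 1)[-1].split(".")


def _component(filename: str) -> str:
    """Top-level folder of a repo path, or "" for root-level files."""
    return filename.split("/")[0] if "/" in filename else ""


def select_files(filenames: List[str], variant: Optional[str], allowed_exts: Set[str]) -> Set[str]:
    # Step 1: extension filter, applied before any variant logic.
    candidates = [f for f in filenames if _parts(f)[-1] in allowed_exts]

    # Step 2: group by component so each folder is decided on its own.
    by_component: Dict[str, List[str]] = defaultdict(list)
    for f in candidates:
        by_component[_component(f)].append(f)

    selected: Set[str] = set()
    for files in by_component.values():
        variants = [
            f for f in files
            if variant is not None and len(_parts(f)) == 3 and _parts(f)[1] == variant
        ]
        non_variants = [f for f in files if len(_parts(f)) == 2]
        # Prefer variant files; fall back to non-variants for this component only.
        selected.update(variants or non_variants)
    return selected
```

With the six-file layout from the review thread, `allowed_exts={"safetensors"}` leaves only the non-variant safetensors, while `allowed_exts={"bin", "safetensors"}` selects the `fp16.bin` files. That is the difference in behaviour the thread is debating, and it shows why the position of the extension filter matters.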
@@ -920,18 +956,13 @@ def _get_custom_components_and_folders(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
We don't need this in this function. The variant_filenames are only being used to raise this error. It's better we consolidate it into variant_compatible_siblings
    return custom_components, folder_names


def _get_ignore_patterns(
    passed_components,
    model_folder_names: List[str],
    model_filenames: List[str],
    variant_filenames: List[str],
variant_filenames is only being used to raise a warning here. We can consolidate this into variant_compatible_siblings and raise the warning there.
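The consolidation suggested in the two review comments above could look roughly like the sketch below: a single place that raises when a requested variant has no files and warns when variant and non-variant files will be mixed. The helper name `check_variant_files`, the message wording, and the exact warning condition are assumptions made for illustration, not the code that landed in `variant_compatible_siblings`.

```python
import logging
from typing import List, Optional

logger = logging.getLogger(__name__)


def check_variant_files(
    variant: Optional[str],
    variant_filenames: List[str],
    model_filenames: List[str],
) -> None:
    """Hypothetical helper: one home for the variant error and warning."""
    if variant is None:
        return  # nothing was requested, so there is nothing to validate

    if len(variant_filenames) == 0:
        # Same condition as the check removed from the diff above.
        raise ValueError(
            f"You are trying to load model files of variant '{variant}', "
            "but no such files are available."
        )

    # Assumed warning condition: some selected files are not variant files.
    non_variant = sorted(set(model_filenames) - set(variant_filenames))
    if non_variant:
        logger.warning(
            "A mixture of %s and non-%s files will be loaded. Non-variant files: %s",
            variant, variant, non_variant,
        )
```

Keeping both checks next to the code that computes `variant_filenames` means callers such as `_get_ignore_patterns` no longer need that argument just to emit a message.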
@yiyixuxu This is ready for another review. It should be able to handle all file format types. I made some changes in
thanks @DN6
I think it is a real nice refactor!!! (and really tricky too)
I left some small feedback on the PR
if component_variants:
    variant_filenames.update(
        component_variants | component_variant_index_files
        if component_variants
so this if-condition is redundant here
how do we intend to handle the legacy_variant?
Oh missed this. Great catch! Updated to account for this.
else:
    variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
awesome!
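For readers following the diff, the move from string splitting to regex filtering can be illustrated with a small sketch. The names `filter_with_regex` and `non_variant_file_re` come from the diff line above, but the bodies here are assumptions: the pattern is a simplified stand-in that only recognises `name.ext` weight files and ignores sharded and index files.

```python
import re
from typing import Iterable, Set

WEIGHT_EXTS = ("safetensors", "bin")  # assumed subset of weight formats


def filter_with_regex(filenames: Iterable[str], pattern: "re.Pattern[str]") -> Set[str]:
    """Keep the files whose basename (folder stripped) matches `pattern`."""
    return {f for f in filenames if pattern.match(f.rsplit("/", 1)[-1])}


# Simplified, illustrative pattern: "<name>.<ext>" with no variant tag in between.
non_variant_file_re = re.compile(rf"^[^.]+\.({'|'.join(WEIGHT_EXTS)})$")

component_filenames = [
    "unet/diffusion_pytorch_model.fp16.safetensors",
    "unet/diffusion_pytorch_model.safetensors",
    "unet/config.json",
]
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
```

Matching on the basename with an anchored pattern avoids the fragility of `filename.split('.')[0]`, which depends on the name containing exactly one dot.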
What does this PR do?
With PR #9869 we fixed downloading sharded variants only and mixed variants. However, we missed the case where a component has sharded non-variant files and a non-sharded variant file, which is the case in issue #10634.
This PR adds a has_variant check and removes the previous checks. The new check first determines whether the file is in a component folder and then whether any variants exist within that folder. If there is no component folder, we skip adding the additional non-variant file.
Since usable_filenames is always populated with variants, we capture the necessary variant files; what we're trying to avoid is extra file downloads.
The only edge case I can think of where this would fail (and which passes with the current implementation) is if the filenames are the following:
Non-variant in the main dir and a variant in a subfolder. Although I think this is an edge case that we probably can't load anyway? I can't think of any pipelines that would have this configuration.
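A minimal sketch of the check described above, under stated assumptions: `find_component` and `extra_non_variants` are hypothetical names, and unlike the `f.startswith(component)` line shown in the review, this version appends `/` to the component so that `text_encoder` does not also match `text_encoder_2`.

```python
from typing import Iterable, Optional, Set


def find_component(filename: str) -> Optional[str]:
    """Return the component folder of a repo path, or None for root-level files."""
    return filename.split("/")[0] if "/" in filename else None


def has_variant(component: str, variant_filenames: Iterable[str]) -> bool:
    """True if any variant file lives inside the given component folder."""
    return any(f.startswith(component + "/") for f in variant_filenames)


def extra_non_variants(non_variant_filenames: Iterable[str], variant_filenames: Set[str]) -> Set[str]:
    """Non-variant files worth adding on top of the variant files."""
    extra: Set[str] = set()
    for filename in non_variant_filenames:
        component = find_component(filename)
        if component is None:
            continue  # no component folder: skip, as described above
        if has_variant(component, variant_filenames):
            continue  # the component is already covered by a variant file
        extra.add(filename)
    return extra
```

For the layout in #10634 (sharded non-variant files next to a single non-sharded variant in the same folder), the sharded files are skipped because their component already has a variant. The sketch also reproduces the edge case mentioned above: a non-variant file in the main directory is never added.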
Fixes #10634