Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for fetching variants only #10646

Merged
merged 19 commits into from
Mar 10, 2025
Merged

Fix for fetching variants only #10646

merged 19 commits into from
Mar 10, 2025

Conversation

DN6
Copy link
Collaborator

@DN6 DN6 commented Jan 24, 2025

What does this PR do?

With PR: #9869 we fixed downloading sharded variants only and mixed variants, however we missed the case where a component might have sharded non-variant files and non-sharded variant file, which is the case with issue: #10634

This PR:

  1. Adds a simpler has_variant check and removes the previous checks. This check first checks to see if we are in a component folder and then checks if any variants exist within the folder. If no folder exists then skip trying to add the additional non variant file.
  2. Adds additional tests to check for the condition mentioned in issue The huggingface repo need to be fixed for Sana 2K and 4K models #10634.

Since usable_filenames is always populated with variants, we capture the necessary variant files and what we're trying to avoid is extra file downloads.

The only edge case I can think of here where this would fail (which passes with the current implementation) is if the filenames are the following:

Non-Variant in the main dir and a variant in a subfolder. Although I think this an edge case that we probably can't load anyway? I can't think of any pipelines that would have this configuration.

filenames = ["diffusion_pytorch_model.safetensors", f"unet/diffusion_pytorch_model.{variant}.safetensors"]

Fixes # (issue)
#10634

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 DN6 requested a review from yiyixuxu January 24, 2025 19:38
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lawrence-cj
Copy link
Contributor

Thanks @DN6 for supporting!

for filename in non_variant_filenames:
if convert_to_variant(filename) in variant_filenames:
continue
return any(f.startswith(component) for f in variant_filenames)
Copy link
Collaborator

@yiyixuxu yiyixuxu Jan 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happends if like we only have a bf16.bin and this is a non-variant safetensors?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in we are trying to fetch something like this?

        variant = "fp16"
        filenames = [
                f"vae/diffusion_pytorch_model.{variant}.bin",
                f"text_encoder/model.{variant}.bin",
                f"unet/diffusion_pytorch_model.{variant}.bin",
        ]
        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like this, I think we should fetch the non-variant safetensors in this case, no?

    variant = "fp16"
    filenames = [
            f"vae/diffusion_pytorch_model.{variant}.bin",
            f"text_encoder/model.{variant}.bin",
            f"unet/diffusion_pytorch_model.{variant}.bin",
            f"vae/diffusion_pytorch_model.safetensors",
            f"text_encoder/model.safetensors",
            f"unet/diffusion_pytorch_model.safetensors",
    ]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm currently the behaviour on main is to return all the files in that list (both bin and safetensors) as usable_filenames and I think the ignore patterns would remove the bin files, resulting in just the safetensors being downloaded.

With this change only the fp16.bin files would be downloaded. Which feels technically "correct" to me since they are the "variant" files of each component. IMO non-variants should only be downloaded if no variant exists (regardless of format)

But this case implies that the proposal here is a breaking change, so I'll update to account for it.

@DN6
Copy link
Collaborator Author

DN6 commented Jan 29, 2025

@yiyixuxu Had to do a bit more of a refactor to account for safetensors prioritization. But it should be much more robust to handle any number of repo file combinations. I've added a test for your case and a few others as well. I think they should cover all likely repo file layout scenarios, but if there are others I may have missed lmk.

@nitinmukesh
Copy link

nitinmukesh commented Feb 1, 2025

@DN6

Here is another test case with fp16, Failing after using 2 updated files from
https://github.com/huggingface/diffusers/pull/10646/files

import torch
from diffusers.utils import load_image
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    variant="fp16",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

Repo
https://huggingface.co/fal/AuraFlow-v0.3/tree/main

Warning: The repository contains sharded checkpoints for variant 'fp16' maybe in a deprecated format. Please check your files carefully:

- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors
- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors

If you find any files in the deprecated format:
1. Remove all existing checkpoint files for this variant.
2. Re-obtain the correct files by running `save_pretrained()`.

This will ensure you're using the most up-to-date and compatible checkpoint format.
model_index.json: 100%|██████████████████████████████████████████████████| 458/458 [00:00<?, ?B/s]


A mixture of fp16 and non-fp16 filenames will be loaded.
Loaded fp16 filenames:
[vae/diffusion_pytorch_model.fp16.safetensors, text_encoder/model.fp16.safetensors]
Loaded non-fp16 filenames:
[transformer/diffusion_pytorch_model-00001-of-00003.safetensors, transformer/diffusion_pytorch_model-00002-of-00003.safetensors, transformer/diffusion_pytorch_model-00003-of-00003.safetensors
If this behavior is not expected, please check your folder structure.
tokenizer.model: 100%|█████████████████████████████████████████| 500k/500k [00:00<00:00, 10.5MB/s]
tokenizer/tokenizer_config.json: 100%|███████████████████████| 21.0k/21.0k [00:00<00:00, 19.1MB/s]
tokenizer/added_tokens.json: 100%|████████████████████████████| 2.59k/2.59k [00:00<00:00, 573kB/s]
scheduler/scheduler_config.json: 100%|███████████████████████████████████| 142/142 [00:00<?, ?B/s]
tokenizer/special_tokens_map.json: 100%|█████████████████████████████| 2.68k/2.68k [00:00<?, ?B/s]
text_encoder/config.json: 100%|███████████████████████████████████| 949/949 [00:00<00:00, 951kB/s]
tokenizer/tokenizer.json: 100%|██████████████████████████████| 1.86M/1.86M [00:00<00:00, 7.64MB/s]
transformer/config.json: 100%|████████████████████████████████████| 379/379 [00:00<00:00, 376kB/s]
vae/config.json: 100%|███████████████████████████████████████████████████| 859/859 [00:00<?, ?B/s]
(…)ion_pytorch_model.safetensors.index.json: 100%|███████████| 36.0k/36.0k [00:00<00:00, 9.04MB/s]
diffusion_pytorch_model.fp16.safetensors: 100%|████████████████| 167M/167M [01:10<00:00, 2.36MB/s]
model.fp16.safetensors:   4%|█▍                               | 126M/2.95G [01:08<25:54, 1.82MB/s]
model.fp16.safetensors:   5%|█▊                               | 157M/2.95G [01:20<20:22, 2.28MB/s]
(…)pytorch_model-00001-of-00003.safetensors:   2%|▎           | 241M/9.99G [01:07<47:46, 3.40MB/s]
(…)pytorch_model-00001-of-00003.safetensors:   3%|▎           | 294M/9.99G [01:21<41:42, 3.88MB/s]
(…)pytorch_model-00002-of-00003.safetensors:   2%|▏         | 178M/9.89G [01:20<1:09:08, 2.34MB/s]
(…)pytorch_model-00003-of-00003.safetensors:   4%|▍           | 273M/7.55G [01:22<31:42, 3.83MB/s]

image

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Feb 5, 2025

it seems we are changing the behavior of variant_compatible_siblings a bit
currently, variant_compatible_siblings does not consider extensions, we address that in the download step

from_flax = kwargs.pop("from_flax", False)
(related arguments are from_flax, use_onnx and use_safetensors)

so my questions is:

  1. is it more efficient to move the extension filter outside of download and earlier in the process
  2. if so, should we apply it too all extensions? is there a reason to single out safetensors?

@DN6
Copy link
Collaborator Author

DN6 commented Feb 6, 2025

is it more efficient to move the extension filter outside of download and earlier in the process

Yes, I think it would be better to add an extension filter/check earlier and then filter variants after (without extensions checks).

The current behaviour of decoupling the steps is a bit confusing and makes it harder to reason about which files end up getting downloaded. For instance in this case:

        filenames = [
            "text_encoder/model.bin",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]

variant_compatible_siblings will currently return all the files in this list, and the is_safetensors_compatible check will return False because not all components have safetensors. Which means bin files won't be filtered out in_get_ignore_patterns and all the unet files in this case will be downloaded. While we could consider this an edge case, I think it illustrates the point that the current file filtering mechanism is not easy to follow.

I think the core issue is that we apply checks/filters currently over the full list of filenames while we should actually be checking per component in order to support mixed downloading properly, as well as things like downloading safetensors over bin files, non-sharded-variants for sharded-non-variants etc.

We can either add additional checks in variant_compatible_siblings or refactor download a bit to make it more clear exactly which files will be used. I can update this PR to do that if that sounds good to you?

if so, should we apply it too all extensions? is there a reason to single out safetensors?

It was an oversight on my part. I was focusing on solving the use case you pointed out and forgot about the use_onnx and from_flax flags.

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Feb 6, 2025

We can either add additional checks in variant_compatible_siblings or refactor download a bit to make it more clear exactly which files will be used. I can update this PR to do that if that sounds good to you?

both approach sounds good to me:)

@@ -920,18 +956,13 @@ def _get_custom_components_and_folders(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need this in this function. The variant_filenames are only being used to raise this error. It's better we consolidate it into variant_compatible_siblings

return custom_components, folder_names


def _get_ignore_patterns(
passed_components,
model_folder_names: List[str],
model_filenames: List[str],
variant_filenames: List[str],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variant_filenames is only being used to raise a warning here. We can consolidate this into variant_compatible_siblings and raise the warning there.

@DN6
Copy link
Collaborator Author

DN6 commented Feb 27, 2025

@yiyixuxu This is ready for another review. It should be able to handle all file format types. I made some changes in pipeline.download so that ignore_patterns is created earlier. We then leverage that to filter out files and variants.

@DN6 DN6 requested a review from yiyixuxu March 3, 2025 07:09
Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @DN6
I think it is a real nice refactor!!! (and really tricky too)

I left some small feedback on the PR

if component_variants:
variant_filenames.update(
component_variants | component_variant_index_files
if component_variants
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this if is redundant here
how we intend to do with the legacy_variant?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh missed this. Great catch! Updated to account for this.

else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome!

@DN6 DN6 added the roadmap Add to current release roadmap label Mar 6, 2025
@DN6 DN6 merged commit 9a1810f into main Mar 10, 2025
29 of 30 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
roadmap Add to current release roadmap
Projects
Development

Successfully merging this pull request may close these issues.

5 participants