Skip to content

Commit fee93c8

Browse files
authored
[Refactor] Update from single file (#6428)
* update * update * update * update * update * update * update * update * update * update * update' * update * update * update * update * update * update * up * update * update * update * update * update * update * update * update * update * update * update * update * up * update * update * update * update * update' * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * clean * update * update * clean up * clean up * update * clean * clean * update * updaet * clean up * fix docs * update * update * Revert "update" This reverts commit dbfb8f1. * update * update * update * update * fix controlnet * fix scheduler * fix controlnet tests
1 parent 5308cce commit fee93c8

22 files changed

+2074
-590
lines changed

docs/source/en/api/loaders/single_file.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ To learn more about how to load single file weights, see the [Load different Sta
3030

3131
## FromOriginalVAEMixin
3232

33-
[[autodoc]] loaders.single_file.FromOriginalVAEMixin
33+
[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin
3434

3535
## FromOriginalControlnetMixin
3636

37-
[[autodoc]] loaders.single_file.FromOriginalControlnetMixin
37+
[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin

src/diffusers/loaders/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,13 @@ def text_encoder_attn_modules(text_encoder):
5454
_import_structure = {}
5555

5656
if is_torch_available():
57-
_import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"]
57+
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
58+
59+
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
5860
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
5961
_import_structure["utils"] = ["AttnProcsLayers"]
60-
6162
if is_transformers_available():
62-
_import_structure["single_file"].extend(["FromSingleFileMixin"])
63+
_import_structure["single_file"] = ["FromSingleFileMixin"]
6364
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
6465
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
6566
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -69,7 +70,8 @@ def text_encoder_attn_modules(text_encoder):
6970

7071
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
7172
if is_torch_available():
72-
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
73+
from .autoencoder import FromOriginalVAEMixin
74+
from .controlnet import FromOriginalControlNetMixin
7375
from .unet import UNet2DConditionLoadersMixin
7476
from .utils import AttnProcsLayers
7577

src/diffusers/loaders/autoencoder.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2023 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from huggingface_hub.utils import validate_hf_hub_args
16+
17+
from .single_file_utils import (
18+
create_diffusers_vae_model_from_ldm,
19+
fetch_ldm_config_and_checkpoint,
20+
)
21+
22+
23+
class FromOriginalVAEMixin:
24+
"""
25+
Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into an [`AutoencoderKL`].
26+
"""
27+
28+
@classmethod
29+
@validate_hf_hub_args
30+
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
31+
r"""
32+
Instantiate an [`AutoencoderKL`] from pretrained AutoencoderKL weights saved in the original `.ckpt` or
33+
`.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
34+
35+
Parameters:
36+
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
37+
Can be either:
38+
- A link to the `.ckpt` file (for example
39+
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
40+
- A path to a *file* containing all pipeline weights.
41+
torch_dtype (`str` or `torch.dtype`, *optional*):
42+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
43+
dtype is automatically derived from the model's weights.
44+
force_download (`bool`, *optional*, defaults to `False`):
45+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
46+
cached versions if they exist.
47+
cache_dir (`Union[str, os.PathLike]`, *optional*):
48+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
49+
is not used.
50+
resume_download (`bool`, *optional*, defaults to `False`):
51+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
52+
incompletely downloaded files are deleted.
53+
proxies (`Dict[str, str]`, *optional*):
54+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
55+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
56+
local_files_only (`bool`, *optional*, defaults to `False`):
57+
Whether to only load local model weights and configuration files or not. If set to True, the model
58+
won't be downloaded from the Hub.
59+
token (`str` or *bool*, *optional*):
60+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
61+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
62+
revision (`str`, *optional*, defaults to `"main"`):
63+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
64+
allowed by Git.
65+
image_size (`int`, *optional*, defaults to 512):
66+
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
67+
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
68+
use_safetensors (`bool`, *optional*, defaults to `True`):
69+
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
70+
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
71+
weights. If set to `False`, safetensors weights are not loaded.
72+
kwargs (remaining dictionary of keyword arguments, *optional*):
73+
Can be used to overwrite load and saveable variables (for example the pipeline components of the
74+
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
75+
method. See example below for more information.
76+
77+
<Tip warning={true}>
78+
79+
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
80+
a VAE from SDXL or a Stable Diffusion v2 model or higher.
81+
82+
</Tip>
83+
84+
Examples:
85+
86+
```py
87+
from diffusers import AutoencoderKL
88+
89+
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
90+
model = AutoencoderKL.from_single_file(url)
91+
```
92+
"""
93+
94+
original_config_file = kwargs.pop("original_config_file", None)
95+
resume_download = kwargs.pop("resume_download", False)
96+
force_download = kwargs.pop("force_download", False)
97+
proxies = kwargs.pop("proxies", None)
98+
token = kwargs.pop("token", None)
99+
cache_dir = kwargs.pop("cache_dir", None)
100+
local_files_only = kwargs.pop("local_files_only", None)
101+
revision = kwargs.pop("revision", None)
102+
torch_dtype = kwargs.pop("torch_dtype", None)
103+
use_safetensors = kwargs.pop("use_safetensors", True)
104+
105+
class_name = cls.__name__
106+
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
107+
pretrained_model_link_or_path=pretrained_model_link_or_path,
108+
class_name=class_name,
109+
original_config_file=original_config_file,
110+
resume_download=resume_download,
111+
force_download=force_download,
112+
proxies=proxies,
113+
token=token,
114+
revision=revision,
115+
local_files_only=local_files_only,
116+
use_safetensors=use_safetensors,
117+
cache_dir=cache_dir,
118+
)
119+
120+
image_size = kwargs.pop("image_size", None)
121+
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
122+
vae = component["vae"]
123+
if torch_dtype is not None:
124+
vae = vae.to(torch_dtype)
125+
126+
return vae

src/diffusers/loaders/controlnet.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2023 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from huggingface_hub.utils import validate_hf_hub_args
16+
17+
from .single_file_utils import (
18+
create_diffusers_controlnet_model_from_ldm,
19+
fetch_ldm_config_and_checkpoint,
20+
)
21+
22+
23+
class FromOriginalControlNetMixin:
24+
"""
25+
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
26+
"""
27+
28+
@classmethod
29+
@validate_hf_hub_args
30+
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
31+
r"""
32+
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
33+
`.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
34+
35+
Parameters:
36+
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
37+
Can be either:
38+
- A link to the `.ckpt` file (for example
39+
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
40+
- A path to a *file* containing all pipeline weights.
41+
torch_dtype (`str` or `torch.dtype`, *optional*):
42+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
43+
dtype is automatically derived from the model's weights.
44+
force_download (`bool`, *optional*, defaults to `False`):
45+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
46+
cached versions if they exist.
47+
cache_dir (`Union[str, os.PathLike]`, *optional*):
48+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
49+
is not used.
50+
resume_download (`bool`, *optional*, defaults to `False`):
51+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
52+
incompletely downloaded files are deleted.
53+
proxies (`Dict[str, str]`, *optional*):
54+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
55+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
56+
local_files_only (`bool`, *optional*, defaults to `False`):
57+
Whether to only load local model weights and configuration files or not. If set to True, the model
58+
won't be downloaded from the Hub.
59+
token (`str` or *bool*, *optional*):
60+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
61+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
62+
revision (`str`, *optional*, defaults to `"main"`):
63+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
64+
allowed by Git.
65+
use_safetensors (`bool`, *optional*, defaults to `True`):
66+
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
67+
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
68+
weights. If set to `False`, safetensors weights are not loaded.
69+
image_size (`int`, *optional*, defaults to 512):
70+
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
71+
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
72+
upcast_attention (`bool`, *optional*, defaults to `None`):
73+
Whether the attention computation should always be upcasted.
74+
kwargs (remaining dictionary of keyword arguments, *optional*):
75+
Can be used to overwrite load and saveable variables (for example the pipeline components of the
76+
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
77+
method. See example below for more information.
78+
79+
Examples:
80+
81+
```py
82+
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
83+
84+
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
85+
controlnet = ControlNetModel.from_single_file(url)
86+
87+
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
88+
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
89+
```
90+
"""
91+
original_config_file = kwargs.pop("original_config_file", None)
92+
resume_download = kwargs.pop("resume_download", False)
93+
force_download = kwargs.pop("force_download", False)
94+
proxies = kwargs.pop("proxies", None)
95+
token = kwargs.pop("token", None)
96+
cache_dir = kwargs.pop("cache_dir", None)
97+
local_files_only = kwargs.pop("local_files_only", None)
98+
revision = kwargs.pop("revision", None)
99+
torch_dtype = kwargs.pop("torch_dtype", None)
100+
use_safetensors = kwargs.pop("use_safetensors", True)
101+
102+
class_name = cls.__name__
103+
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
104+
pretrained_model_link_or_path=pretrained_model_link_or_path,
105+
class_name=class_name,
106+
original_config_file=original_config_file,
107+
resume_download=resume_download,
108+
force_download=force_download,
109+
proxies=proxies,
110+
token=token,
111+
revision=revision,
112+
local_files_only=local_files_only,
113+
use_safetensors=use_safetensors,
114+
cache_dir=cache_dir,
115+
)
116+
117+
upcast_attention = kwargs.pop("upcast_attention", False)
118+
image_size = kwargs.pop("image_size", None)
119+
120+
component = create_diffusers_controlnet_model_from_ldm(
121+
class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
122+
)
123+
controlnet = component["controlnet"]
124+
if torch_dtype is not None:
125+
controlnet = controlnet.to(torch_dtype)
126+
127+
return controlnet

0 commit comments

Comments
 (0)