-
Notifications
You must be signed in to change notification settings - Fork 6.1k
The Modular Diffusers #9672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
The Modular Diffusers #9672
Changes from all commits
33f85fa
52a7f1c
e8d0980
ad3f9a2
ddea157
af9572d
2b6dcbf
70272b1
46ec174
f1b3036
540d303
6742f16
005195c
024a9f5
37e8dc7
8b811fe
c70a285
ffc2992
ace53e2
a8df0f1
e50d614
bc3d1c9
2b3cd2d
b305c77
0b90051
806e8e6
4fa85c7
72d9a81
10d4a77
27dde51
8c02572
a09ca7f
ed59f90
72c5bf0
6c93626
1d63306
2e0f5c8
c12a05b
54f410d
6985906
db94ca8
e973de6
7a34832
2220af6
fb78f4f
0966663
7f897a9
a6804de
7007f72
a226920
77b5fa5
6e2fe26
68a5185
d046cf7
71df158
b3fb418
00cae4e
ccb35ac
00a3bc9
4bed3e3
c7020df
2c3e4ea
e5089d7
8ddb20b
cff0fd6
485f8d1
addaad0
12650e1
96795af
6a509ba
a8e853b
7ad01a6
5a8c1b5
8913d59
45392cc
9e58856
04c16d0
083479c
4751d45
d12531d
19545fd
78d2454
085ade0
42c06e9
1ae591e
bb40443
7c78fb1
48e4ff5
e49413d
ffbaa89
cdaaa40
1c9f0a8
174628e
c0327e4
5917d70
8c038f0
cb328d3
7d2a633
74b908b
9530245
c437ae7
f3453f0
a82e211
a33206d
75e6238
129d658
da4242d
ab6d634
7492e33
b92cda2
61772f0
9abac85
84f4b27
449f299
7608d2e
f63d62e
655512e
885a596
b543bcc
75540f4
93760b1
9aaec5b
58dbe0c
49ea4d1
92b6b43
8c680bc
fedaa00
fdd2bed
3a3441c
9fae382
b43e703
c75b88f
285f877
f09b1cc
c5849ba
363737e
bbd9340
0138e17
db4b54c
abf28d5
4b12a60
f27fbce
98ea5c9
b5db8aa
4543d21
1987c07
2e20241
13fe248
8cb5b08
3e46c86
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,134 @@ | ||||||
# Copyright 2025 The HuggingFace Team. All rights reserved. | ||||||
# | ||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
# you may not use this file except in compliance with the License. | ||||||
# You may obtain a copy of the License at | ||||||
# | ||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||
# | ||||||
# Unless required by applicable law or agreed to in writing, software | ||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
|
||||||
""" | ||||||
Usage example: | ||||||
TODO | ||||||
""" | ||||||
|
||||||
import ast | ||||||
import importlib.util | ||||||
import os | ||||||
from argparse import ArgumentParser, Namespace | ||||||
from pathlib import Path | ||||||
|
||||||
from ..utils import logging | ||||||
from . import BaseDiffusersCLICommand | ||||||
|
||||||
|
||||||
EXPECTED_PARENT_CLASSES = ["PipelineBlock"] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
CONFIG = "config.json" | ||||||
|
||||||
|
||||||
def conversion_command_factory(args: Namespace): | ||||||
return CustomBlocksCommand(args.block_module_name, args.block_class_name) | ||||||
|
||||||
|
||||||
class CustomBlocksCommand(BaseDiffusersCLICommand): | ||||||
@staticmethod | ||||||
def register_subcommand(parser: ArgumentParser): | ||||||
conversion_parser = parser.add_parser("custom_blocks") | ||||||
conversion_parser.add_argument( | ||||||
"--block_module_name", | ||||||
type=str, | ||||||
default="block.py", | ||||||
help="Module filename in which the custom block will be implemented.", | ||||||
) | ||||||
conversion_parser.add_argument( | ||||||
"--block_class_name", | ||||||
type=str, | ||||||
default=None, | ||||||
help="Name of the custom block. If provided None, we will try to infer it.", | ||||||
) | ||||||
conversion_parser.set_defaults(func=conversion_command_factory) | ||||||
|
||||||
def __init__(self, block_module_name: str = "block.py", block_class_name: str = None): | ||||||
self.logger = logging.get_logger("diffusers-cli/custom_blocks") | ||||||
self.block_module_name = Path(block_module_name) | ||||||
self.block_class_name = block_class_name | ||||||
|
||||||
def run(self): | ||||||
# determine the block to be saved. | ||||||
out = self._get_class_names(self.block_module_name) | ||||||
classes_found = list({cls for cls, _ in out}) | ||||||
|
||||||
if self.block_class_name is not None: | ||||||
child_class, parent_class = self._choose_block(out, self.block_class_name) | ||||||
if child_class is None and parent_class is None: | ||||||
raise ValueError( | ||||||
"`block_class_name` could not be retrieved. Available classes from " | ||||||
f"{self.block_module_name}:\n{classes_found}" | ||||||
) | ||||||
else: | ||||||
self.logger.info( | ||||||
f"Found classes: {classes_found} will be using {classes_found[0]}. " | ||||||
"If this needs to be changed, re-run the command specifying `block_class_name`." | ||||||
) | ||||||
child_class, parent_class = out[0][0], out[0][1] | ||||||
|
||||||
# dynamically get the custom block and initialize it to call `save_pretrained` in the current directory. | ||||||
# the user is responsible for running it, so I guess that is safe? | ||||||
module_name = f"__dynamic__{self.block_module_name.stem}" | ||||||
spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name)) | ||||||
module = importlib.util.module_from_spec(spec) | ||||||
spec.loader.exec_module(module) | ||||||
getattr(module, child_class)().save_pretrained(os.getcwd()) | ||||||
|
||||||
# or, we could create it manually. | ||||||
# automap = self._create_automap(parent_class=parent_class, child_class=child_class) | ||||||
# with open(CONFIG, "w") as f: | ||||||
# json.dump(automap, f) | ||||||
with open("requirements.txt", "w") as f: | ||||||
f.write("") | ||||||
|
||||||
def _choose_block(self, candidates, chosen=None): | ||||||
for cls, base in candidates: | ||||||
if cls == chosen: | ||||||
return cls, base | ||||||
return None, None | ||||||
|
||||||
def _get_class_names(self, file_path): | ||||||
source = file_path.read_text(encoding="utf-8") | ||||||
try: | ||||||
tree = ast.parse(source, filename=file_path) | ||||||
except SyntaxError as e: | ||||||
raise ValueError(f"Could not parse {file_path!r}: {e}") from e | ||||||
|
||||||
results: list[tuple[str, str]] = [] | ||||||
for node in tree.body: | ||||||
if not isinstance(node, ast.ClassDef): | ||||||
continue | ||||||
|
||||||
# extract all base names for this class | ||||||
base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None] | ||||||
|
||||||
# for each allowed base that appears in the class's bases, emit a tuple | ||||||
for allowed in EXPECTED_PARENT_CLASSES: | ||||||
if allowed in base_names: | ||||||
results.append((node.name, allowed)) | ||||||
|
||||||
return results | ||||||
|
||||||
def _get_base_name(self, node: ast.expr): | ||||||
if isinstance(node, ast.Name): | ||||||
return node.id | ||||||
elif isinstance(node, ast.Attribute): | ||||||
val = self._get_base_name(node.value) | ||||||
return f"{val}.{node.attr}" if val else node.attr | ||||||
return None | ||||||
|
||||||
def _create_automap(self, parent_class, child_class): | ||||||
module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1] | ||||||
auto_map = {f"{parent_class}": f"{module}.{child_class}"} | ||||||
return {"auto_map": auto_map} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright 2025 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Union | ||
|
||
from ..utils import is_torch_available | ||
|
||
|
||
if is_torch_available(): | ||
from .adaptive_projected_guidance import AdaptiveProjectedGuidance | ||
from .auto_guidance import AutoGuidance | ||
from .classifier_free_guidance import ClassifierFreeGuidance | ||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance | ||
from .skip_layer_guidance import SkipLayerGuidance | ||
from .smoothed_energy_guidance import SmoothedEnergyGuidance | ||
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance | ||
|
||
GuiderType = Union[ | ||
AdaptiveProjectedGuidance, | ||
AutoGuidance, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @a-r-r-o-w can we make sure PAG has its own class before we merge? |
||
ClassifierFreeGuidance, | ||
ClassifierFreeZeroStarGuidance, | ||
SkipLayerGuidance, | ||
SmoothedEnergyGuidance, | ||
TangentialClassifierFreeGuidance, | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@DN6 @sayakpaul
I merged in this part of code without reviewing it (I need the
save_pretrained()
code and it works fine)can you let me know if you want to keep it in here when we merge this PR or remove &convert it into a new PR