Skip to content

Commit c70a285

Browse files
committed
style
1 parent 8b811fe commit c70a285

File tree

5 files changed

+55
-81
lines changed

5 files changed

+55
-81
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,12 @@
145145
"DDIMPipeline",
146146
"DDPMPipeline",
147147
"DiffusionPipeline",
148-
"ModularPipelineBuilder",
149148
"DiTPipeline",
150149
"ImagePipelineOutput",
151150
"KarrasVePipeline",
152151
"LDMPipeline",
153152
"LDMSuperResolutionPipeline",
153+
"ModularPipelineBuilder",
154154
"PNDMPipeline",
155155
"RePaintPipeline",
156156
"ScoreSdeVePipeline",
@@ -366,11 +366,11 @@
366366
"StableDiffusionXLImg2ImgPipeline",
367367
"StableDiffusionXLInpaintPipeline",
368368
"StableDiffusionXLInstructPix2PixPipeline",
369+
"StableDiffusionXLModularPipeline",
369370
"StableDiffusionXLPAGImg2ImgPipeline",
370371
"StableDiffusionXLPAGInpaintPipeline",
371372
"StableDiffusionXLPAGPipeline",
372373
"StableDiffusionXLPipeline",
373-
"StableDiffusionXLModularPipeline",
374374
"StableUnCLIPImg2ImgPipeline",
375375
"StableUnCLIPPipeline",
376376
"StableVideoDiffusionPipeline",

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ def _get_model(pipeline_class_name):
220220
if pipeline.__name__ == pipeline_class_name:
221221
return model_name
222222

223-
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
224223

224+
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
225225
model_name = _get_model(pipeline_class_name)
226226

227227
if model_name is not None:

src/diffusers/pipelines/modular_pipeline_builder.py

Lines changed: 44 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,22 @@
1414

1515
import inspect
1616
from dataclasses import dataclass, field
17-
from typing import Any, Dict, List, Optional, Tuple, Union
18-
import importlib
19-
from collections import OrderedDict
20-
import PIL
17+
from typing import Any, Dict, List, Tuple, Union
18+
2119
import torch
2220
from tqdm.auto import tqdm
2321

2422
from ..configuration_utils import ConfigMixin
25-
from ..loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
26-
from ..models import ImageProjection
27-
from ..models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
28-
from ..models.lora import adjust_lora_scale_text_encoder
2923
from ..utils import (
30-
USE_PEFT_BACKEND,
3124
is_accelerate_available,
3225
is_accelerate_version,
3326
logging,
34-
scale_lora_layers,
35-
unscale_lora_layers,
3627
)
3728
from ..utils.hub_utils import validate_hf_hub_args
38-
from ..utils.torch_utils import randn_tensor
39-
from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class
40-
from .pipeline_utils import DiffusionPipeline, StableDiffusionMixin
4129
from .auto_pipeline import _get_model
30+
from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class
31+
from .pipeline_utils import DiffusionPipeline
32+
4233

4334
if is_accelerate_available():
4435
import accelerate
@@ -51,7 +42,6 @@
5142
}
5243

5344

54-
5545
@dataclass
5646
class PipelineState:
5747
"""
@@ -225,6 +215,7 @@ class ModularPipelineBuilder(ConfigMixin):
225215
Base class for all Modular pipelines.
226216
227217
"""
218+
228219
config_name = "model_index.json"
229220
model_cpu_offload_seq = None
230221
hf_device_map = None
@@ -316,7 +307,7 @@ def components(self) -> Dict[str, Any]:
316307
expected_components = set()
317308
for block in self.pipeline_blocks:
318309
expected_components.update(block.components.keys())
319-
310+
320311
components = {}
321312
for name in expected_components:
322313
if hasattr(self, name):
@@ -349,8 +340,8 @@ def auxiliaries(self) -> Dict[str, Any]:
349340
@property
350341
def configs(self) -> Dict[str, Any]:
351342
r"""
352-
The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the
353-
pipeline blocks.
343+
The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the pipeline
344+
blocks.
354345
355346
Returns (`dict`):
356347
A dictionary containing all the configs defined in the pipeline blocks.
@@ -393,31 +384,32 @@ def __call__(self, *args, **kwargs):
393384

394385
def remove_blocks(self, indices: Union[int, List[int]]):
395386
"""
396-
Remove one or more blocks from the pipeline by their indices and clean up associated components,
397-
configs, and auxiliaries that are no longer needed by remaining blocks.
387+
Remove one or more blocks from the pipeline by their indices and clean up associated components, configs, and
388+
auxiliaries that are no longer needed by remaining blocks.
398389
399390
Args:
400391
indices (Union[int, List[int]]): The index or list of indices of blocks to remove
401392
"""
402393
# Convert single index to list
403394
indices = [indices] if isinstance(indices, int) else indices
404-
395+
405396
# Validate indices
406397
for idx in indices:
407398
if not 0 <= idx < len(self.pipeline_blocks):
408-
raise ValueError(f"Invalid block index {idx}. Index must be between 0 and {len(self.pipeline_blocks) - 1}")
409-
399+
raise ValueError(
400+
f"Invalid block index {idx}. Index must be between 0 and {len(self.pipeline_blocks) - 1}"
401+
)
402+
410403
# Sort indices in descending order to avoid shifting issues when removing
411404
indices = sorted(indices, reverse=True)
412-
405+
413406
# Store blocks to be removed
414407
blocks_to_remove = [self.pipeline_blocks[idx] for idx in indices]
415-
408+
416409
# Remove blocks from pipeline
417410
for idx in indices:
418411
self.pipeline_blocks.pop(idx)
419412

420-
421413
# Consolidate items to remove from all blocks
422414
components_to_remove = {k: v for block in blocks_to_remove for k, v in block.components.items()}
423415
auxiliaries_to_remove = {k: v for block in blocks_to_remove for k, v in block.auxiliaries.items()}
@@ -448,15 +440,15 @@ def remove_blocks(self, indices: Union[int, List[int]]):
448440

449441
def add_blocks(self, pipeline_blocks, at: int = -1):
450442
"""Add blocks to the pipeline.
451-
443+
452444
Args:
453445
pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances.
454446
at (int, optional): Index at which to insert the blocks. Defaults to -1 (append at end).
455447
"""
456448
# Convert single block to list for uniform processing
457449
if not isinstance(pipeline_blocks, (list, tuple)):
458450
pipeline_blocks = [pipeline_blocks]
459-
451+
460452
# Validate insert_at index
461453
if at != -1 and not 0 <= at <= len(self.pipeline_blocks):
462454
raise ValueError(f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks)}")
@@ -465,24 +457,24 @@ def add_blocks(self, pipeline_blocks, at: int = -1):
465457
components_to_add = {}
466458
configs_to_add = {}
467459
auxiliaries_to_add = {}
468-
460+
469461
# Add blocks in order
470462
for i, block in enumerate(pipeline_blocks):
471463
# Add block to pipeline at specified position
472464
if at == -1:
473465
self.pipeline_blocks.append(block)
474466
else:
475467
self.pipeline_blocks.insert(at + i, block)
476-
468+
477469
# Collect components that don't already exist
478470
for k, v in block.components.items():
479471
if not hasattr(self, k) or (getattr(self, k, None) is None and v is not None):
480472
components_to_add[k] = v
481-
473+
482474
# Collect configs and auxiliaries
483475
configs_to_add.update(block.configs)
484476
auxiliaries_to_add.update(block.auxiliaries)
485-
477+
486478
# Validate all required components and auxiliaries after consolidation
487479
for block in pipeline_blocks:
488480
for required_component in block.required_components:
@@ -513,44 +505,37 @@ def add_blocks(self, pipeline_blocks, at: int = -1):
513505
if configs_to_add:
514506
self.register_to_config(**configs_to_add)
515507
for key, value in auxiliaries_to_add.items():
516-
517508
setattr(self, key, value)
518509

519510
def replace_blocks(self, pipeline_blocks, at: int):
520511
"""Replace one or more blocks in the pipeline at the specified index.
521-
512+
522513
Args:
523-
pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances
514+
pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances
524515
that will replace existing blocks.
525516
at (int): Index at which to replace the blocks.
526517
"""
527518
# Convert single block to list for uniform processing
528519
if not isinstance(pipeline_blocks, (list, tuple)):
529520
pipeline_blocks = [pipeline_blocks]
530-
521+
531522
# Validate replace_at index
532523
if not 0 <= at < len(self.pipeline_blocks):
533-
raise ValueError(
534-
f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks) - 1}"
535-
)
536-
524+
raise ValueError(f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks) - 1}")
525+
537526
# Add new blocks first
538527
self.add_blocks(pipeline_blocks, at=at)
539-
528+
540529
# Calculate indices to remove
541530
# We need to remove the original blocks that are now shifted by the length of pipeline_blocks
542-
indices_to_remove = list(range(
543-
at + len(pipeline_blocks),
544-
at + len(pipeline_blocks) * 2
545-
))
546-
531+
indices_to_remove = list(range(at + len(pipeline_blocks), at + len(pipeline_blocks) * 2))
532+
547533
# Remove the old blocks
548534
self.remove_blocks(indices_to_remove)
549535

550536
@classmethod
551537
@validate_hf_hub_args
552538
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
553-
554539
# (1) create the base pipeline
555540
cache_dir = kwargs.pop("cache_dir", None)
556541
force_download = kwargs.pop("force_download", False)
@@ -579,47 +564,41 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
579564
modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)]
580565
modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name)
581566

582-
583567
# (3) create the pipeline blocks
584568
pipeline_blocks = [
585-
block_class.from_pipe(base_pipeline)
586-
for block_class in modular_pipeline_class.default_pipeline_blocks
569+
block_class.from_pipe(base_pipeline) for block_class in modular_pipeline_class.default_pipeline_blocks
587570
]
588571

589572
# (4) create the builder
590573
builder = modular_pipeline_class()
591574
builder.add_blocks(pipeline_blocks)
592575

593576
return builder
594-
577+
595578
@classmethod
596579
def from_pipe(cls, pipeline, **kwargs):
597580
base_pipeline_class_name = pipeline.__class__.__name__
598581
modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)]
599582
modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name)
600-
583+
601584
pipeline_blocks = []
602585
# Create each block, passing only unused items that the block expects
603586
for block_class in modular_pipeline_class.default_pipeline_blocks:
604587
expected_components = set(block_class.required_components + block_class.optional_components)
605588
expected_auxiliaries = set(block_class.required_auxiliaries)
606-
589+
607590
# Get init parameters to check for expected configs
608591
init_params = inspect.signature(block_class.__init__).parameters
609592
expected_configs = {
610-
k for k in init_params
611-
if k not in expected_components
612-
and k not in expected_auxiliaries
593+
k for k in init_params if k not in expected_components and k not in expected_auxiliaries
613594
}
614-
595+
615596
block_kwargs = {}
616-
597+
617598
for key, value in kwargs.items():
618-
if (key in expected_components or
619-
key in expected_auxiliaries or
620-
key in expected_configs):
599+
if key in expected_components or key in expected_auxiliaries or key in expected_configs:
621600
block_kwargs[key] = value
622-
601+
623602
# Create the block with filtered kwargs
624603
block = block_class.from_pipe(pipeline, **block_kwargs)
625604
pipeline_blocks.append(block)
@@ -630,10 +609,10 @@ def from_pipe(cls, pipeline, **kwargs):
630609

631610
# Warn about unused kwargs
632611
unused_kwargs = {
633-
k: v for k, v in kwargs.items()
612+
k: v
613+
for k, v in kwargs.items()
634614
if not any(
635-
k in block.components or k in block.auxiliaries or k in block.configs
636-
for block in pipeline_blocks
615+
k in block.components or k in block.auxiliaries or k in block.configs for block in pipeline_blocks
637616
)
638617
}
639618
if unused_kwargs:
@@ -774,7 +753,6 @@ def __repr__(self):
774753
output += f"{name}: {config!r}\n"
775754
output += "\n"
776755

777-
778756
# List the default call parameters
779757
output += "Default Call Parameters:\n"
780758
output += "------------------------\n"

src/diffusers/pipelines/stable_diffusion_xl/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
3131
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
3232
_import_structure["pipeline_stable_diffusion_xl_modular"] = [
33+
"StableDiffusionXLControlNetDenoiseStep",
3334
"StableDiffusionXLDecodeLatentsStep",
3435
"StableDiffusionXLDenoiseStep",
3536
"StableDiffusionXLInputStep",
@@ -38,7 +39,6 @@
3839
"StableDiffusionXLPrepareLatentsStep",
3940
"StableDiffusionXLSetTimestepsStep",
4041
"StableDiffusionXLTextEncoderStep",
41-
"StableDiffusionXLControlNetDenoiseStep",
4242
]
4343

4444
if is_transformers_available() and is_flax_available():
@@ -60,6 +60,7 @@
6060
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
6161
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
6262
from .pipeline_stable_diffusion_xl_modular import (
63+
StableDiffusionXLControlNetDenoiseStep,
6364
StableDiffusionXLDecodeLatentsStep,
6465
StableDiffusionXLDenoiseStep,
6566
StableDiffusionXLInputStep,
@@ -68,7 +69,6 @@
6869
StableDiffusionXLPrepareLatentsStep,
6970
StableDiffusionXLSetTimestepsStep,
7071
StableDiffusionXLTextEncoderStep,
71-
StableDiffusionXLControlNetDenoiseStep,
7272
)
7373

7474
try:

0 commit comments

Comments
 (0)