@@ -14,31 +14,22 @@

 import inspect
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Tuple, Union
-import importlib
-from collections import OrderedDict
-import PIL
+from typing import Any, Dict, List, Tuple, Union
+
 import torch
 from tqdm.auto import tqdm

 from ..configuration_utils import ConfigMixin
-from ..loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ..models import ImageProjection
-from ..models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
-from ..models.lora import adjust_lora_scale_text_encoder
 from ..utils import (
-    USE_PEFT_BACKEND,
     is_accelerate_available,
     is_accelerate_version,
     logging,
-    scale_lora_layers,
-    unscale_lora_layers,
 )
 from ..utils.hub_utils import validate_hf_hub_args
-from ..utils.torch_utils import randn_tensor
-from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class
-from .pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .auto_pipeline import _get_model
+from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class
+from .pipeline_utils import DiffusionPipeline
+

 if is_accelerate_available():
     import accelerate
@@ -51,7 +42,6 @@
 }


-
 @dataclass
 class PipelineState:
     """
@@ -225,6 +215,7 @@ class ModularPipelineBuilder(ConfigMixin):
     Base class for all Modular pipelines.

     """
+
     config_name = "model_index.json"
     model_cpu_offload_seq = None
     hf_device_map = None
@@ -316,7 +307,7 @@ def components(self) -> Dict[str, Any]:
         expected_components = set()
         for block in self.pipeline_blocks:
             expected_components.update(block.components.keys())
-
+
         components = {}
         for name in expected_components:
             if hasattr(self, name):
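For reference, a minimal standalone sketch of the aggregation pattern the `components` property uses above; the `Block` class and component names are made-up stand-ins, not actual diffusers objects:

```python
# Collect the component names every block declares; the builder can then look up
# whatever it currently holds under those names (mirrors the loop in `components`).
class Block:
    def __init__(self, components):
        self.components = components


pipeline_blocks = [Block({"unet": None}), Block({"vae": None, "unet": None})]

expected_components = set()
for block in pipeline_blocks:
    expected_components.update(block.components.keys())

print(sorted(expected_components))  # ['unet', 'vae'] -- duplicates across blocks collapse to one entry
```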
@@ -349,8 +340,8 @@ def auxiliaries(self) -> Dict[str, Any]:
     @property
     def configs(self) -> Dict[str, Any]:
         r"""
-        The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the
-        pipeline blocks.
+        The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the pipeline
+        blocks.

         Returns (`dict`):
             A dictionary containing all the configs defined in the pipeline blocks.
@@ -393,31 +384,32 @@ def __call__(self, *args, **kwargs):

     def remove_blocks(self, indices: Union[int, List[int]]):
         """
-        Remove one or more blocks from the pipeline by their indices and clean up associated components,
-        configs, and auxiliaries that are no longer needed by remaining blocks.
+        Remove one or more blocks from the pipeline by their indices and clean up associated components, configs, and
+        auxiliaries that are no longer needed by remaining blocks.

         Args:
             indices (Union[int, List[int]]): The index or list of indices of blocks to remove
         """
         # Convert single index to list
         indices = [indices] if isinstance(indices, int) else indices
-
+
         # Validate indices
         for idx in indices:
             if not 0 <= idx < len(self.pipeline_blocks):
-                raise ValueError(f"Invalid block index {idx}. Index must be between 0 and {len(self.pipeline_blocks) - 1}")
-
+                raise ValueError(
+                    f"Invalid block index {idx}. Index must be between 0 and {len(self.pipeline_blocks) - 1}"
+                )
+
         # Sort indices in descending order to avoid shifting issues when removing
         indices = sorted(indices, reverse=True)
-
+
         # Store blocks to be removed
         blocks_to_remove = [self.pipeline_blocks[idx] for idx in indices]
-
+
         # Remove blocks from pipeline
         for idx in indices:
             self.pipeline_blocks.pop(idx)

-
         # Consolidate items to remove from all blocks
         components_to_remove = {k: v for block in blocks_to_remove for k, v in block.components.items()}
         auxiliaries_to_remove = {k: v for block in blocks_to_remove for k, v in block.auxiliaries.items()}
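As a side note on the removal loop above, a tiny self-contained example of why the indices are sorted in descending order before popping (the list contents are placeholders):

```python
# Popping lower indices first would shift the later ones; descending order avoids that.
blocks = ["text_encoder", "denoise", "decode", "postprocess"]

for idx in sorted([1, 3], reverse=True):  # drop "denoise" and "postprocess"
    blocks.pop(idx)

print(blocks)  # ['text_encoder', 'decode']
```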
@@ -448,15 +440,15 @@ def remove_blocks(self, indices: Union[int, List[int]]):

     def add_blocks(self, pipeline_blocks, at: int = -1):
         """Add blocks to the pipeline.
-
+
         Args:
             pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances.
             at (int, optional): Index at which to insert the blocks. Defaults to -1 (append at end).
         """
         # Convert single block to list for uniform processing
         if not isinstance(pipeline_blocks, (list, tuple)):
             pipeline_blocks = [pipeline_blocks]
-
+
         # Validate insert_at index
         if at != -1 and not 0 <= at <= len(self.pipeline_blocks):
             raise ValueError(f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks)}")
@@ -465,24 +457,24 @@ def add_blocks(self, pipeline_blocks, at: int = -1):
         components_to_add = {}
         configs_to_add = {}
         auxiliaries_to_add = {}
-
+
         # Add blocks in order
         for i, block in enumerate(pipeline_blocks):
             # Add block to pipeline at specified position
             if at == -1:
                 self.pipeline_blocks.append(block)
             else:
                 self.pipeline_blocks.insert(at + i, block)
-
+
             # Collect components that don't already exist
             for k, v in block.components.items():
                 if not hasattr(self, k) or (getattr(self, k, None) is None and v is not None):
                     components_to_add[k] = v
-
+
             # Collect configs and auxiliaries
             configs_to_add.update(block.configs)
             auxiliaries_to_add.update(block.auxiliaries)
-
+
         # Validate all required components and auxiliaries after consolidation
         for block in pipeline_blocks:
             for required_component in block.required_components:
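For context on the insertion loop above, a small standalone sketch (plain strings stand in for PipelineBlock instances) showing how `insert(at + i, block)` keeps a list of new blocks in their given order:

```python
existing = ["A", "B", "C"]
new_blocks = ["X", "Y"]
at = 1

for i, block in enumerate(new_blocks):
    existing.insert(at + i, block)  # each block lands right after the previous one

print(existing)  # ['A', 'X', 'Y', 'B', 'C']
```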
@@ -513,44 +505,37 @@ def add_blocks(self, pipeline_blocks, at: int = -1):
         if configs_to_add:
             self.register_to_config(**configs_to_add)
         for key, value in auxiliaries_to_add.items():
-
             setattr(self, key, value)

     def replace_blocks(self, pipeline_blocks, at: int):
         """Replace one or more blocks in the pipeline at the specified index.
-
+
         Args:
-            pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances
+            pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances
                 that will replace existing blocks.
             at (int): Index at which to replace the blocks.
         """
         # Convert single block to list for uniform processing
         if not isinstance(pipeline_blocks, (list, tuple)):
             pipeline_blocks = [pipeline_blocks]
-
+
         # Validate replace_at index
         if not 0 <= at < len(self.pipeline_blocks):
-            raise ValueError(
-                f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks) - 1}"
-            )
-
+            raise ValueError(f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks) - 1}")
+
         # Add new blocks first
         self.add_blocks(pipeline_blocks, at=at)
-
+
         # Calculate indices to remove
         # We need to remove the original blocks that are now shifted by the length of pipeline_blocks
-        indices_to_remove = list(range(
-            at + len(pipeline_blocks),
-            at + len(pipeline_blocks) * 2
-        ))
-
+        indices_to_remove = list(range(at + len(pipeline_blocks), at + len(pipeline_blocks) * 2))
+
         # Remove the old blocks
         self.remove_blocks(indices_to_remove)

     @classmethod
     @validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_or_path, **kwargs):
-
         # (1) create the base pipeline
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
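To illustrate the index arithmetic in `replace_blocks` above: the replacements are inserted first, which shifts the originals right by `len(pipeline_blocks)`, so the blocks to delete now sit at `range(at + n, at + 2 * n)`. A self-contained sketch with placeholder strings:

```python
blocks = ["A", "B", "C", "D"]
new_blocks = ["X", "Y"]
at = 1

for i, b in enumerate(new_blocks):  # step 1: insert the replacements
    blocks.insert(at + i, b)
# blocks == ['A', 'X', 'Y', 'B', 'C', 'D']

n = len(new_blocks)
for idx in sorted(range(at + n, at + 2 * n), reverse=True):  # step 2: drop the shifted originals
    blocks.pop(idx)

print(blocks)  # ['A', 'X', 'Y', 'D'] -- 'B' and 'C' were replaced
```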
@@ -579,47 +564,41 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)]
         modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name)

-
         # (3) create the pipeline blocks
         pipeline_blocks = [
-            block_class.from_pipe(base_pipeline)
-            for block_class in modular_pipeline_class.default_pipeline_blocks
+            block_class.from_pipe(base_pipeline) for block_class in modular_pipeline_class.default_pipeline_blocks
         ]

         # (4) create the builder
         builder = modular_pipeline_class()
         builder.add_blocks(pipeline_blocks)

         return builder
-
+
     @classmethod
     def from_pipe(cls, pipeline, **kwargs):
         base_pipeline_class_name = pipeline.__class__.__name__
         modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)]
         modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name)
-
+
         pipeline_blocks = []
         # Create each block, passing only unused items that the block expects
         for block_class in modular_pipeline_class.default_pipeline_blocks:
             expected_components = set(block_class.required_components + block_class.optional_components)
             expected_auxiliaries = set(block_class.required_auxiliaries)
-
+
             # Get init parameters to check for expected configs
             init_params = inspect.signature(block_class.__init__).parameters
             expected_configs = {
-                k for k in init_params
-                if k not in expected_components
-                and k not in expected_auxiliaries
+                k for k in init_params if k not in expected_components and k not in expected_auxiliaries
             }
-
+
             block_kwargs = {}
-
+
             for key, value in kwargs.items():
-                if (key in expected_components or
-                    key in expected_auxiliaries or
-                    key in expected_configs):
+                if key in expected_components or key in expected_auxiliaries or key in expected_configs:
                     block_kwargs[key] = value
-
+
             # Create the block with filtered kwargs
             block = block_class.from_pipe(pipeline, **block_kwargs)
             pipeline_blocks.append(block)
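A rough, self-contained sketch of the kwarg filtering `from_pipe` performs per block: each block only receives the keyword arguments its `__init__` (or its declared components/auxiliaries) expects. `ToyBlock` and its parameters are hypothetical placeholders, not real diffusers classes:

```python
import inspect


class ToyBlock:
    required_components = ["unet"]
    optional_components = []
    required_auxiliaries = []

    def __init__(self, unet=None, guidance_scale=7.5):
        self.unet = unet
        self.guidance_scale = guidance_scale


expected_components = set(ToyBlock.required_components + ToyBlock.optional_components)
expected_auxiliaries = set(ToyBlock.required_auxiliaries)
init_params = inspect.signature(ToyBlock.__init__).parameters
expected_configs = {k for k in init_params if k not in expected_components and k not in expected_auxiliaries}

kwargs = {"guidance_scale": 5.0, "unrelated_option": True}
block_kwargs = {
    k: v for k, v in kwargs.items()
    if k in expected_components or k in expected_auxiliaries or k in expected_configs
}
print(block_kwargs)  # {'guidance_scale': 5.0} -- 'unrelated_option' would later be reported as unused
```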
@@ -630,10 +609,10 @@ def from_pipe(cls, pipeline, **kwargs):

         # Warn about unused kwargs
         unused_kwargs = {
-            k: v for k, v in kwargs.items()
+            k: v
+            for k, v in kwargs.items()
             if not any(
-                k in block.components or k in block.auxiliaries or k in block.configs
-                for block in pipeline_blocks
+                k in block.components or k in block.auxiliaries or k in block.configs for block in pipeline_blocks
             )
         }
         if unused_kwargs:
@@ -774,7 +753,6 @@ def __repr__(self):
             output += f"{name}: {config!r}\n"
         output += "\n"

-
         # List the default call parameters
         output += "Default Call Parameters:\n"
         output += "------------------------\n"