Skip to content

Commit 8b811fe

Browse files
committed
refactor, from_pretrained, from_pipe, remove_blocks, replace_blocks
1 parent 37e8dc7 commit 8b811fe

File tree

7 files changed

+2411
-1945
lines changed

7 files changed

+2411
-1945
lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@
145145
"DDIMPipeline",
146146
"DDPMPipeline",
147147
"DiffusionPipeline",
148+
"ModularPipelineBuilder",
148149
"DiTPipeline",
149150
"ImagePipelineOutput",
150151
"KarrasVePipeline",
@@ -369,6 +370,7 @@
369370
"StableDiffusionXLPAGInpaintPipeline",
370371
"StableDiffusionXLPAGPipeline",
371372
"StableDiffusionXLPipeline",
373+
"StableDiffusionXLModularPipeline",
372374
"StableUnCLIPImg2ImgPipeline",
373375
"StableUnCLIPPipeline",
374376
"StableVideoDiffusionPipeline",
@@ -626,6 +628,7 @@
626628
KarrasVePipeline,
627629
LDMPipeline,
628630
LDMSuperResolutionPipeline,
631+
ModularPipelineBuilder,
629632
PNDMPipeline,
630633
RePaintPipeline,
631634
ScoreSdeVePipeline,
@@ -819,6 +822,7 @@
819822
StableDiffusionXLImg2ImgPipeline,
820823
StableDiffusionXLInpaintPipeline,
821824
StableDiffusionXLInstructPix2PixPipeline,
825+
StableDiffusionXLModularPipeline,
822826
StableDiffusionXLPAGImg2ImgPipeline,
823827
StableDiffusionXLPAGInpaintPipeline,
824828
StableDiffusionXLPAGPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"AutoPipelineForInpainting",
4747
"AutoPipelineForText2Image",
4848
]
49+
_import_structure["modular_pipeline_builder"] = ["ModularPipelineBuilder"]
4950
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
5051
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
5152
_import_structure["ddim"] = ["DDIMPipeline"]
@@ -296,6 +297,7 @@
296297
"StableDiffusionXLInpaintPipeline",
297298
"StableDiffusionXLInstructPix2PixPipeline",
298299
"StableDiffusionXLPipeline",
300+
"StableDiffusionXLModularPipeline",
299301
]
300302
)
301303
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
@@ -432,6 +434,7 @@
432434
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
433435
from .dit import DiTPipeline
434436
from .latent_diffusion import LDMSuperResolutionPipeline
437+
from .modular_pipeline_builder import ModularPipelineBuilder
435438
from .pipeline_utils import (
436439
AudioPipelineOutput,
437440
DiffusionPipeline,
@@ -620,6 +623,7 @@
620623
StableDiffusionXLImg2ImgPipeline,
621624
StableDiffusionXLInpaintPipeline,
622625
StableDiffusionXLInstructPix2PixPipeline,
626+
StableDiffusionXLModularPipeline,
623627
StableDiffusionXLPipeline,
624628
)
625629
from .stable_video_diffusion import StableVideoDiffusionPipeline

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,14 +214,15 @@ def _get_connected_pipeline(pipeline_cls):
214214
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)
215215

216216

217+
def _get_model(pipeline_class_name):
218+
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
219+
for model_name, pipeline in task_mapping.items():
220+
if pipeline.__name__ == pipeline_class_name:
221+
return model_name
222+
217223
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
218-
def get_model(pipeline_class_name):
219-
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
220-
for model_name, pipeline in task_mapping.items():
221-
if pipeline.__name__ == pipeline_class_name:
222-
return model_name
223224

224-
model_name = get_model(pipeline_class_name)
225+
model_name = _get_model(pipeline_class_name)
225226

226227
if model_name is not None:
227228
task_class = mapping.get(model_name, None)

0 commit comments

Comments
 (0)