@@ -757,15 +757,9 @@ def check_inputs(
757
757
for images_ in image :
758
758
for image_ in images_ :
759
759
self .check_image (image_ , prompt , prompt_embeds )
760
- else :
761
- assert False
762
760
763
761
# Check `controlnet_conditioning_scale`
764
- # TODO Update for https://github.com/huggingface/diffusers/pull/10723
765
- if isinstance (controlnet , ControlNetUnionModel ):
766
- if not isinstance (controlnet_conditioning_scale , float ):
767
- raise TypeError ("For single controlnet: `controlnet_conditioning_scale` must be type `float`." )
768
- elif isinstance (controlnet , MultiControlNetUnionModel ):
762
+ if isinstance (controlnet , MultiControlNetUnionModel ):
769
763
if isinstance (controlnet_conditioning_scale , list ):
770
764
if any (isinstance (i , list ) for i in controlnet_conditioning_scale ):
771
765
raise ValueError ("A single batch of multiple conditionings is not supported at the moment." )
@@ -776,8 +770,6 @@ def check_inputs(
776
770
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
777
771
" the same length as the number of controlnets"
778
772
)
779
- else :
780
- assert False
781
773
782
774
if len (control_guidance_start ) != len (control_guidance_end ):
783
775
raise ValueError (
@@ -808,8 +800,6 @@ def check_inputs(
808
800
for _control_mode , _controlnet in zip (control_mode , self .controlnet .nets ):
809
801
if max (_control_mode ) >= _controlnet .config .num_control_type :
810
802
raise ValueError (f"control_mode: must be lower than { _controlnet .config .num_control_type } ." )
811
- else :
812
- assert False
813
803
814
804
# Equal number of `image` and `control_mode` elements
815
805
if isinstance (controlnet , ControlNetUnionModel ):
@@ -823,8 +813,6 @@ def check_inputs(
823
813
824
814
elif sum (len (x ) for x in image ) != sum (len (x ) for x in control_mode ):
825
815
raise ValueError ("Expected len(control_image) == len(control_mode)" )
826
- else :
827
- assert False
828
816
829
817
if ip_adapter_image is not None and ip_adapter_image_embeds is not None :
830
818
raise ValueError (
@@ -1201,28 +1189,33 @@ def __call__(
1201
1189
1202
1190
controlnet = self .controlnet ._orig_mod if is_compiled_module (self .controlnet ) else self .controlnet
1203
1191
1192
+ if not isinstance (control_image , list ):
1193
+ control_image = [control_image ]
1194
+ else :
1195
+ control_image = control_image .copy ()
1196
+
1197
+ if not isinstance (control_mode , list ):
1198
+ control_mode = [control_mode ]
1199
+
1200
+ if isinstance (controlnet , MultiControlNetUnionModel ):
1201
+ control_image = [[item ] for item in control_image ]
1202
+ control_mode = [[item ] for item in control_mode ]
1203
+
1204
1204
# align format for control guidance
1205
1205
if not isinstance (control_guidance_start , list ) and isinstance (control_guidance_end , list ):
1206
1206
control_guidance_start = len (control_guidance_end ) * [control_guidance_start ]
1207
1207
elif not isinstance (control_guidance_end , list ) and isinstance (control_guidance_start , list ):
1208
1208
control_guidance_end = len (control_guidance_start ) * [control_guidance_end ]
1209
1209
elif not isinstance (control_guidance_start , list ) and not isinstance (control_guidance_end , list ):
1210
- mult = len (controlnet .nets ) if isinstance (controlnet , MultiControlNetUnionModel ) else 1
1210
+ mult = len (controlnet .nets ) if isinstance (controlnet , MultiControlNetUnionModel ) else len ( control_mode )
1211
1211
control_guidance_start , control_guidance_end = (
1212
1212
mult * [control_guidance_start ],
1213
1213
mult * [control_guidance_end ],
1214
1214
)
1215
1215
1216
- if not isinstance (control_image , list ):
1217
- control_image = [control_image ]
1218
- else :
1219
- control_image = control_image .copy ()
1220
-
1221
- if not isinstance (control_mode , list ):
1222
- control_mode = [control_mode ]
1223
-
1224
- if isinstance (controlnet , MultiControlNetUnionModel ) and isinstance (controlnet_conditioning_scale , float ):
1225
- controlnet_conditioning_scale = [controlnet_conditioning_scale ] * len (controlnet .nets )
1216
+ if isinstance (controlnet_conditioning_scale , float ):
1217
+ mult = len (controlnet .nets ) if isinstance (controlnet , MultiControlNetUnionModel ) else len (control_mode )
1218
+ controlnet_conditioning_scale = [controlnet_conditioning_scale ] * mult
1226
1219
1227
1220
# 1. Check inputs
1228
1221
self .check_inputs (
@@ -1357,9 +1350,6 @@ def __call__(
1357
1350
control_image = control_images
1358
1351
height , width = control_image [0 ][0 ].shape [- 2 :]
1359
1352
1360
- else :
1361
- assert False
1362
-
1363
1353
# 5. Prepare timesteps
1364
1354
timesteps , num_inference_steps = retrieve_timesteps (
1365
1355
self .scheduler , num_inference_steps , device , timesteps , sigmas
@@ -1397,7 +1387,7 @@ def __call__(
1397
1387
1.0 - float (i / len (timesteps ) < s or (i + 1 ) / len (timesteps ) > e )
1398
1388
for s , e in zip (control_guidance_start , control_guidance_end )
1399
1389
]
1400
- controlnet_keep .append (keeps [ 0 ] if isinstance ( controlnet , ControlNetUnionModel ) else keeps )
1390
+ controlnet_keep .append (keeps )
1401
1391
1402
1392
# 7.2 Prepare added time ids & embeddings
1403
1393
original_size = original_size or (height , width )
0 commit comments