Commit 37a5f1b

Experimental per control type scale for ControlNet Union (#10723)
* ControlNet Union scale
* fix
* universal interface
* from_multi
* from_multi
1 parent 501d9de commit 37a5f1b

3 files changed: +47 −39 lines changed

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 24 additions & 10 deletions
@@ -605,12 +605,13 @@ def forward(
         controlnet_cond: List[torch.Tensor],
         control_type: torch.Tensor,
         control_type_idx: List[int],
-        conditioning_scale: float = 1.0,
+        conditioning_scale: Union[float, List[float]] = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        from_multi: bool = False,
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
@@ -647,6 +648,8 @@ def forward(
                 Additional conditions for the Stable Diffusion XL UNet.
             cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+            from_multi (`bool`, defaults to `False`):
+                Use standard scaling when called from `MultiControlNetUnionModel`.
             guess_mode (`bool`, defaults to `False`):
                 In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                 you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
@@ -658,6 +661,9 @@ def forward(
             If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
             returned where the first element is the sample tensor.
         """
+        if isinstance(conditioning_scale, float):
+            conditioning_scale = [conditioning_scale] * len(controlnet_cond)
+
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order

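For reference, a hedged sketch of the scalar-to-list broadcast added above; the two-entry list is an illustrative assumption, not taken from the diff.

# Mirrors the new normalization: a scalar conditioning_scale becomes one scale
# per conditioning input, so later code can always index per control type.
conditioning_scale = 0.8
controlnet_cond = ["depth", "canny"]  # placeholders for two conditioning tensors

if isinstance(conditioning_scale, float):
    conditioning_scale = [conditioning_scale] * len(controlnet_cond)

print(conditioning_scale)  # [0.8, 0.8]
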
@@ -742,12 +748,16 @@ def forward(
         inputs = []
         condition_list = []
 
-        for cond, control_idx in zip(controlnet_cond, control_type_idx):
+        for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
             condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
             feat_seq = feat_seq + self.task_embedding[control_idx]
-            inputs.append(feat_seq.unsqueeze(1))
-            condition_list.append(condition)
+            if from_multi:
+                inputs.append(feat_seq.unsqueeze(1))
+                condition_list.append(condition)
+            else:
+                inputs.append(feat_seq.unsqueeze(1) * scale)
+                condition_list.append(condition * scale)
 
         condition = sample
         feat_seq = torch.mean(condition, dim=(2, 3))
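
To show where the new per-type scale lands, here is a hedged sketch of the scaling performed in the loop above; the tensor shapes are assumptions for a typical SDXL ControlNet and not values from this diff.

import torch

# Assumed shapes: batch 1, 320 channels, 64x64 conditioning features.
condition = torch.randn(1, 320, 64, 64)   # stands in for controlnet_cond_embedding(cond)
task_embedding = torch.randn(320)         # stands in for self.task_embedding[control_idx]
scale = 0.5                               # per-control-type scale for this condition

feat_seq = torch.mean(condition, dim=(2, 3)) + task_embedding
token = feat_seq.unsqueeze(1) * scale     # scaled transformer token (from_multi=False path)
scaled_condition = condition * scale      # scaled spatial features added to the fused sample
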
@@ -759,10 +769,13 @@ def forward(
             x = layer(x)
 
         controlnet_cond_fuser = sample * 0.0
-        for idx, condition in enumerate(condition_list[:-1]):
+        for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            controlnet_cond_fuser += condition + alpha
+            if from_multi:
+                controlnet_cond_fuser += condition + alpha
+            else:
+                controlnet_cond_fuser += condition + alpha * scale
 
         sample = sample + controlnet_cond_fuser
 
@@ -806,12 +819,13 @@ def forward(
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-            scales = scales * conditioning_scale
+            if from_multi:
+                scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
-        else:
-            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-            mid_block_res_sample = mid_block_res_sample * conditioning_scale
+        elif from_multi:
+            down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
+            mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
 
         if self.config.global_pool_conditions:
             down_block_res_samples = [

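The last hunk keeps the old uniform scaling only for the from_multi path. A hedged sketch of that guess-mode branch; the number of residuals and tensor sizes are illustrative assumptions.

import torch

down_block_res_samples = [torch.ones(1, 4, 8, 8) for _ in range(9)]
mid_block_res_sample = torch.ones(1, 4, 8, 8)
conditioning_scale = [0.7]  # already broadcast to a list at the top of forward()

scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 ... 1.0
scales = scales * conditioning_scale[0]                          # only when from_multi=True
down_block_res_samples = [s * scale for s, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1]         # last one
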
src/diffusers/models/controlnets/multicontrolnet_union.py

Lines changed: 5 additions & 1 deletion
@@ -47,9 +47,12 @@ def forward(
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
+        down_block_res_samples, mid_block_res_sample = None, None
         for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
             zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
         ):
+            if scale == 0.0:
+                continue
             down_samples, mid_sample = controlnet(
                 sample=sample,
                 timestep=timestep,
@@ -63,12 +66,13 @@ def forward(
                 attention_mask=attention_mask,
                 added_cond_kwargs=added_cond_kwargs,
                 cross_attention_kwargs=cross_attention_kwargs,
+                from_multi=True,
                 guess_mode=guess_mode,
                 return_dict=return_dict,
             )
 
             # merge samples
-            if i == 0:
+            if down_block_res_samples is None and mid_block_res_sample is None:
                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
             else:
                 down_block_res_samples = [

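Taken together, the wrapper changes mean a controlnet with scale 0.0 is skipped entirely and the accumulators start as None. A hedged, self-contained sketch of that merge logic; run_net is a hypothetical stand-in for the per-net forward call.

import torch

def run_net(value):
    # Hypothetical stand-in for ControlNetUnionModel.forward: returns fake residuals.
    down = [torch.full((1, 4, 8, 8), value) for _ in range(3)]
    mid = torch.full((1, 4, 8, 8), value)
    return down, mid

conditioning_scale = [1.0, 0.0, 0.5]  # the second net contributes nothing and is skipped

down_block_res_samples, mid_block_res_sample = None, None
for value, scale in zip([1.0, 2.0, 3.0], conditioning_scale):
    if scale == 0.0:
        continue
    down_samples, mid_sample = run_net(value)
    if down_block_res_samples is None and mid_block_res_sample is None:
        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
    else:
        down_block_res_samples = [d + s for d, s in zip(down_block_res_samples, down_samples)]
        mid_block_res_sample = mid_block_res_sample + mid_sample
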
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 18 additions & 28 deletions
@@ -757,15 +757,9 @@ def check_inputs(
             for images_ in image:
                 for image_ in images_:
                     self.check_image(image_, prompt, prompt_embeds)
-        else:
-            assert False
 
         # Check `controlnet_conditioning_scale`
-        # TODO Update for https://github.com/huggingface/diffusers/pull/10723
-        if isinstance(controlnet, ControlNetUnionModel):
-            if not isinstance(controlnet_conditioning_scale, float):
-                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
-        elif isinstance(controlnet, MultiControlNetUnionModel):
+        if isinstance(controlnet, MultiControlNetUnionModel):
             if isinstance(controlnet_conditioning_scale, list):
                 if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                     raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
@@ -776,8 +770,6 @@ def check_inputs(
                         "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                         " the same length as the number of controlnets"
                     )
-        else:
-            assert False
 
         if len(control_guidance_start) != len(control_guidance_end):
             raise ValueError(
@@ -808,8 +800,6 @@ def check_inputs(
             for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
                 if max(_control_mode) >= _controlnet.config.num_control_type:
                     raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
-        else:
-            assert False
 
         # Equal number of `image` and `control_mode` elements
         if isinstance(controlnet, ControlNetUnionModel):
@@ -823,8 +813,6 @@ def check_inputs(
 
         elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
             raise ValueError("Expected len(control_image) == len(control_mode)")
-        else:
-            assert False
 
         if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
             raise ValueError(
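
With the float-only check removed, check_inputs now lets a single ControlNetUnionModel receive either a scalar or a per-control-type list; only the multi-controlnet list length is still validated. Illustrative (not exhaustive) examples of values that pass:

# Single ControlNetUnionModel: scalar or one scale per control type.
controlnet_conditioning_scale = 0.8          # still valid, broadcast later in __call__
controlnet_conditioning_scale = [0.6, 0.9]   # new: per control type (e.g. depth, canny)

# MultiControlNetUnionModel with two nets: one scale per net.
controlnet_conditioning_scale = [0.5, 1.0]
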
@@ -1201,28 +1189,33 @@ def __call__(
 
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
+        if not isinstance(control_image, list):
+            control_image = [control_image]
+        else:
+            control_image = control_image.copy()
+
+        if not isinstance(control_mode, list):
+            control_mode = [control_mode]
+
+        if isinstance(controlnet, MultiControlNetUnionModel):
+            control_image = [[item] for item in control_image]
+            control_mode = [[item] for item in control_mode]
+
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
             control_guidance_start, control_guidance_end = (
                 mult * [control_guidance_start],
                 mult * [control_guidance_end],
             )
 
-        if not isinstance(control_image, list):
-            control_image = [control_image]
-        else:
-            control_image = control_image.copy()
-
-        if not isinstance(control_mode, list):
-            control_mode = [control_mode]
-
-        if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
-            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+        if isinstance(controlnet_conditioning_scale, float):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
 
         # 1. Check inputs
         self.check_inputs(
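
A hedged sketch of the input normalization the hunk above performs, written as plain Python without pipeline objects; the image placeholders and the is_multi flag are assumptions for illustration.

control_image = ["depth.png", "canny.png"]  # placeholders for PIL images
control_mode = [1, 3]
is_multi = False  # True when self.controlnet is a MultiControlNetUnionModel

if not isinstance(control_image, list):
    control_image = [control_image]
else:
    control_image = control_image.copy()
if not isinstance(control_mode, list):
    control_mode = [control_mode]
if is_multi:
    control_image = [[item] for item in control_image]
    control_mode = [[item] for item in control_mode]

controlnet_conditioning_scale = 0.8
if isinstance(controlnet_conditioning_scale, float):
    mult = 2 if is_multi else len(control_mode)   # len(controlnet.nets) in the real pipeline
    controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
# -> [0.8, 0.8]: one scale per requested control type for a single union model
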
@@ -1357,9 +1350,6 @@ def __call__(
             control_image = control_images
             height, width = control_image[0][0].shape[-2:]
 
-        else:
-            assert False
-
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, timesteps, sigmas
@@ -1397,7 +1387,7 @@ def __call__(
                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
-            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps)
+            controlnet_keep.append(keeps)
 
             # 7.2 Prepare added time ids & embeddings
             original_size = original_size or (height, width)

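End to end, the change lets a single ControlNetUnionModel take one conditioning scale per control type from the SDXL pipeline. A hedged usage sketch; the model ids, image files, and control mode indices below are illustrative assumptions, not part of this commit.

import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionPipeline
from diffusers.utils import load_image

controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

depth_image = load_image("depth.png")  # placeholder conditioning images
canny_image = load_image("canny.png")

image = pipe(
    prompt="a house in the forest",
    control_image=[depth_image, canny_image],
    control_mode=[1, 3],                        # indices into the union model's control types
    controlnet_conditioning_scale=[0.6, 0.9],   # experimental: one scale per control type
).images[0]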