Skip to content

Commit 0914fa3

Browse files
committed
ControlNet Union scale
1 parent 5b1dcd1 commit 0914fa3

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -605,7 +605,7 @@ def forward(
605605
controlnet_cond: List[torch.Tensor],
606606
control_type: torch.Tensor,
607607
control_type_idx: List[int],
608-
conditioning_scale: float = 1.0,
608+
conditioning_scale: Union[float, List[float]] = 1.0,
609609
class_labels: Optional[torch.Tensor] = None,
610610
timestep_cond: Optional[torch.Tensor] = None,
611611
attention_mask: Optional[torch.Tensor] = None,
@@ -658,6 +658,9 @@ def forward(
658658
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
659659
returned where the first element is the sample tensor.
660660
"""
661+
if isinstance(conditioning_scale, float):
662+
conditioning_scale = [conditioning_scale] * len(controlnet_cond)
663+
661664
# check channel order
662665
channel_order = self.config.controlnet_conditioning_channel_order
663666

@@ -742,12 +745,12 @@ def forward(
742745
inputs = []
743746
condition_list = []
744747

745-
for cond, control_idx in zip(controlnet_cond, control_type_idx):
748+
for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
746749
condition = self.controlnet_cond_embedding(cond)
747750
feat_seq = torch.mean(condition, dim=(2, 3))
748751
feat_seq = feat_seq + self.task_embedding[control_idx]
749-
inputs.append(feat_seq.unsqueeze(1))
750-
condition_list.append(condition)
752+
inputs.append(feat_seq.unsqueeze(1) * scale)
753+
condition_list.append(condition * scale)
751754

752755
condition = sample
753756
feat_seq = torch.mean(condition, dim=(2, 3))
@@ -759,10 +762,10 @@ def forward(
759762
x = layer(x)
760763

761764
controlnet_cond_fuser = sample * 0.0
762-
for idx, condition in enumerate(condition_list[:-1]):
765+
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
763766
alpha = self.spatial_ch_projs(x[:, idx])
764767
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
765-
controlnet_cond_fuser += condition + alpha
768+
controlnet_cond_fuser += condition + alpha * scale
766769

767770
sample = sample + controlnet_cond_fuser
768771

@@ -806,12 +809,8 @@ def forward(
806809
# 6. scaling
807810
if guess_mode and not self.config.global_pool_conditions:
808811
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
809-
scales = scales * conditioning_scale
810812
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
811813
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
812-
else:
813-
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
814-
mid_block_res_sample = mid_block_res_sample * conditioning_scale
815814

816815
if self.config.global_pool_conditions:
817816
down_block_res_samples = [

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1132,20 +1132,29 @@ def __call__(
11321132

11331133
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
11341134

1135+
if not isinstance(control_mode, list):
1136+
control_mode = [control_mode]
1137+
11351138
# align format for control guidance
11361139
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
11371140
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
11381141
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
11391142
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1143+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1144+
mult = len(control_mode)
1145+
control_guidance_start, control_guidance_end = (
1146+
mult * [control_guidance_start],
1147+
mult * [control_guidance_end],
1148+
)
1149+
1150+
if isinstance(controlnet_conditioning_scale, float):
1151+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(control_mode)
11401152

11411153
if not isinstance(control_image, list):
11421154
control_image = [control_image]
11431155
else:
11441156
control_image = control_image.copy()
11451157

1146-
if not isinstance(control_mode, list):
1147-
control_mode = [control_mode]
1148-
11491158
if len(control_image) != len(control_mode):
11501159
raise ValueError("Expected len(control_image) == len(control_type)")
11511160

@@ -1278,10 +1287,11 @@ def __call__(
12781287
# 7.1 Create tensor stating which controlnets to keep
12791288
controlnet_keep = []
12801289
for i in range(len(timesteps)):
1281-
controlnet_keep.append(
1282-
1.0
1283-
- float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
1284-
)
1290+
keeps = [
1291+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1292+
for s, e in zip(control_guidance_start, control_guidance_end)
1293+
]
1294+
controlnet_keep.append(keeps)
12851295

12861296
# 7.2 Prepare added time ids & embeddings
12871297
original_size = original_size or (height, width)

0 commit comments

Comments
 (0)