@@ -605,7 +605,7 @@ def forward(
605
605
controlnet_cond : List [torch .Tensor ],
606
606
control_type : torch .Tensor ,
607
607
control_type_idx : List [int ],
608
- conditioning_scale : float = 1.0 ,
608
+ conditioning_scale : Union [ float , List [ float ]] = 1.0 ,
609
609
class_labels : Optional [torch .Tensor ] = None ,
610
610
timestep_cond : Optional [torch .Tensor ] = None ,
611
611
attention_mask : Optional [torch .Tensor ] = None ,
@@ -658,6 +658,9 @@ def forward(
658
658
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
659
659
returned where the first element is the sample tensor.
660
660
"""
661
+ if isinstance (conditioning_scale , float ):
662
+ conditioning_scale = [conditioning_scale ] * len (controlnet_cond )
663
+
661
664
# check channel order
662
665
channel_order = self .config .controlnet_conditioning_channel_order
663
666
@@ -742,12 +745,12 @@ def forward(
742
745
inputs = []
743
746
condition_list = []
744
747
745
- for cond , control_idx in zip (controlnet_cond , control_type_idx ):
748
+ for cond , control_idx , scale in zip (controlnet_cond , control_type_idx , conditioning_scale ):
746
749
condition = self .controlnet_cond_embedding (cond )
747
750
feat_seq = torch .mean (condition , dim = (2 , 3 ))
748
751
feat_seq = feat_seq + self .task_embedding [control_idx ]
749
- inputs .append (feat_seq .unsqueeze (1 ))
750
- condition_list .append (condition )
752
+ inputs .append (feat_seq .unsqueeze (1 ) * scale )
753
+ condition_list .append (condition * scale )
751
754
752
755
condition = sample
753
756
feat_seq = torch .mean (condition , dim = (2 , 3 ))
@@ -759,10 +762,10 @@ def forward(
759
762
x = layer (x )
760
763
761
764
controlnet_cond_fuser = sample * 0.0
762
- for idx , condition in enumerate (condition_list [:- 1 ]):
765
+ for ( idx , condition ), scale in zip ( enumerate (condition_list [:- 1 ]), conditioning_scale ):
763
766
alpha = self .spatial_ch_projs (x [:, idx ])
764
767
alpha = alpha .unsqueeze (- 1 ).unsqueeze (- 1 )
765
- controlnet_cond_fuser += condition + alpha
768
+ controlnet_cond_fuser += condition + alpha * scale
766
769
767
770
sample = sample + controlnet_cond_fuser
768
771
@@ -806,12 +809,8 @@ def forward(
806
809
# 6. scaling
807
810
if guess_mode and not self .config .global_pool_conditions :
808
811
scales = torch .logspace (- 1 , 0 , len (down_block_res_samples ) + 1 , device = sample .device ) # 0.1 to 1.0
809
- scales = scales * conditioning_scale
810
812
down_block_res_samples = [sample * scale for sample , scale in zip (down_block_res_samples , scales )]
811
813
mid_block_res_sample = mid_block_res_sample * scales [- 1 ] # last one
812
- else :
813
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples ]
814
- mid_block_res_sample = mid_block_res_sample * conditioning_scale
815
814
816
815
if self .config .global_pool_conditions :
817
816
down_block_res_samples = [
0 commit comments