@@ -787,8 +787,8 @@ def extra_conds(self, **kwargs):
         return out
 
 class Flux(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
-        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
+    def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
+        super().__init__(model_config, model_type, device=device, unet_model=unet_model)
 
     def concat_cond(self, **kwargs):
         try:
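
This hunk turns the previously hard-coded unet class into a keyword argument whose default preserves the old behavior, so a subclass can inject its own diffusion model class while reusing everything else in Flux. A minimal sketch of the pattern, using hypothetical stand-in classes rather than the real comfy.ldm modules:

# Sketch only: FluxUnet/ChromaUnet are stand-ins for
# comfy.ldm.flux.model.Flux and comfy.ldm.chroma.model.Chroma.
class FluxUnet: pass
class ChromaUnet: pass

class BaseModel:
    def __init__(self, model_config, unet_model):
        self.diffusion_model = unet_model()  # build whichever class was passed in

class Flux(BaseModel):
    # default keeps existing callers working; subclasses may override unet_model
    def __init__(self, model_config, unet_model=FluxUnet):
        super().__init__(model_config, unet_model=unet_model)

class Chroma(Flux):
    def __init__(self, model_config):
        # inherit all of Flux's logic, swapping only the unet class
        super().__init__(model_config, unet_model=ChromaUnet)

assert isinstance(Chroma(None).diffusion_model, ChromaUnet)
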
@@ -1110,63 +1110,14 @@ def extra_conds(self, **kwargs):
         out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
         return out
 
-class Chroma(BaseModel):
+class Chroma(Flux):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
 
-    def concat_cond(self, **kwargs):
-        try:
-            #Handle Flux control loras dynamically changing the img_in weight.
-            num_channels = self.diffusion_model.img_in.weight.shape[1]
-        except:
-            #Some cases like tensorrt might not have the weights accessible
-            num_channels = self.model_config.unet_config["in_channels"]
-
-        out_channels = self.model_config.unet_config["out_channels"]
-
-        if num_channels <= out_channels:
-            return None
-
-        image = kwargs.get("concat_latent_image", None)
-        noise = kwargs.get("noise", None)
-        device = kwargs["device"]
-
-        if image is None:
-            image = torch.zeros_like(noise)
-
-        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
-        image = utils.resize_to_batch_size(image, noise.shape[0])
-        image = self.process_latent_in(image)
-        if num_channels <= out_channels * 2:
-            return image
-
-        #inpaint model
-        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
-        if mask is None:
-            mask = torch.ones_like(noise)[:, :1]
-
-        mask = torch.mean(mask, dim=1, keepdim=True)
-        mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
-        mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
-        mask = utils.resize_to_batch_size(mask, noise.shape[0])
-        return torch.cat((image, mask), dim=1)
-
-
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
-        cross_attn = kwargs.get("cross_attn", None)
-        if cross_attn is not None:
-            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
-        # upscale the attention mask, since now we
-        attention_mask = kwargs.get("attention_mask", None)
-        if attention_mask is not None:
-            shape = kwargs["noise"].shape
-            mask_ref_size = kwargs["attention_mask_img_shape"]
-            # the model will pad to the patch size, and then divide
-            # essentially dividing and rounding up
-            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
-            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
-            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
-        guidance = 0.0
-        out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
+
+        guidance = kwargs.get("guidance", 0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out
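
With Chroma subclassing Flux, the super().extra_conds(**kwargs) call now resolves to Flux's implementation, which already handles cross_attn and the attention-mask upscaling the deleted lines duplicated; the override only swaps the guidance default to 0. A small sketch of that delegation, with illustrative values (the Flux-side default shown below is an assumption, not part of this diff):

# Sketch of the super()-delegation after this commit; values illustrative.
class Flux:
    def extra_conds(self, **kwargs):
        # stands in for Flux's real extra_conds (cross_attn, attention_mask, ...)
        return {"guidance": kwargs.get("guidance", 3.5)}  # assumed Flux default

class Chroma(Flux):
    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)   # reuse Flux's conditioning
        guidance = kwargs.get("guidance", 0)  # Chroma defaults guidance to 0
        if guidance is not None:
            out["guidance"] = guidance
        return out

print(Chroma().extra_conds())              # {'guidance': 0}
print(Chroma().extra_conds(guidance=2.0))  # {'guidance': 2.0}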