
Commit 08ff5fa

Cleanup chroma PR.
1 parent 4ca3d84 commit 08ff5fa

File tree: 9 files changed, +25 -181 lines

comfy/ldm/chroma/layers.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor, nn

-from .math import attention
+from comfy.ldm.flux.math import attention
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
     RMSNorm,

comfy/ldm/chroma/math.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

comfy/lora.py

Lines changed: 1 addition & 1 deletion

@@ -252,7 +252,7 @@ def model_lora_keys_unet(model, key_map={}):
        key_lora = k[len("diffusion_model."):-len(".weight")]
        key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format

-    if isinstance(model, comfy.model_base.Flux) or isinstance(model, comfy.model_base.Chroma): #Diffusers lora Flux or a diffusers lora Chroma
+    if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
        diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
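Why dropping the second isinstance check is safe: later in this commit (comfy/model_base.py below), Chroma is redefined as a subclass of Flux, and isinstance() matches subclass instances, so the single Flux check still catches diffusers-format Chroma loras. A minimal sketch with toy stand-in classes, not the real comfy.model_base types:

class BaseModel: pass
class Flux(BaseModel): pass
class Chroma(Flux): pass    # mirrors the new "class Chroma(Flux)" below

model = Chroma()
assert isinstance(model, Flux)       # True: subclass instances match
assert isinstance(model, BaseModel)  # True: matching is transitive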

comfy/model_base.py

Lines changed: 7 additions & 56 deletions

@@ -787,8 +787,8 @@ def extra_conds(self, **kwargs):
        return out

 class Flux(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
-        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
+    def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
+        super().__init__(model_config, model_type, device=device, unet_model=unet_model)

     def concat_cond(self, **kwargs):
         try:

@@ -1110,63 +1110,14 @@ def extra_conds(self, **kwargs):
        out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
        return out

-class Chroma(BaseModel):
+class Chroma(Flux):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)

-    def concat_cond(self, **kwargs):
-        try:
-            #Handle Flux control loras dynamically changing the img_in weight.
-            num_channels = self.diffusion_model.img_in.weight.shape[1]
-        except:
-            #Some cases like tensorrt might not have the weights accessible
-            num_channels = self.model_config.unet_config["in_channels"]
-
-        out_channels = self.model_config.unet_config["out_channels"]
-
-        if num_channels <= out_channels:
-            return None
-
-        image = kwargs.get("concat_latent_image", None)
-        noise = kwargs.get("noise", None)
-        device = kwargs["device"]
-
-        if image is None:
-            image = torch.zeros_like(noise)
-
-        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
-        image = utils.resize_to_batch_size(image, noise.shape[0])
-        image = self.process_latent_in(image)
-        if num_channels <= out_channels * 2:
-            return image
-
-        #inpaint model
-        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
-        if mask is None:
-            mask = torch.ones_like(noise)[:, :1]
-
-        mask = torch.mean(mask, dim=1, keepdim=True)
-        mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
-        mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
-        mask = utils.resize_to_batch_size(mask, noise.shape[0])
-        return torch.cat((image, mask), dim=1)
-
-
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
-        cross_attn = kwargs.get("cross_attn", None)
-        if cross_attn is not None:
-            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
-        # upscale the attention mask, since now we
-        attention_mask = kwargs.get("attention_mask", None)
-        if attention_mask is not None:
-            shape = kwargs["noise"].shape
-            mask_ref_size = kwargs["attention_mask_img_shape"]
-            # the model will pad to the patch size, and then divide
-            # essentially dividing and rounding up
-            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
-            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
-            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
-        guidance = 0.0
-        out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
+
+        guidance = kwargs.get("guidance", 0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out
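The net effect of these two hunks: Flux.__init__ now accepts its inner diffusion module class as a parameter, so Chroma can subclass Flux, pass its own network, and inherit concat_cond (and the shared parts of extra_conds) instead of carrying near-duplicate copies. A hedged sketch of the pattern, with toy classes standing in for the comfy types:

class FluxNet: pass
class ChromaNet: pass

class Base:
    def __init__(self, unet_model):
        self.diffusion_model = unet_model()

class Flux(Base):
    # unet_model is now a parameter, with the old hardcoded class as default
    def __init__(self, unet_model=FluxNet):
        super().__init__(unet_model=unet_model)

    def concat_cond(self):
        return "shared inpaint/control concat logic"

class Chroma(Flux):
    def __init__(self):
        # only the network class differs; everything else is inherited
        super().__init__(unet_model=ChromaNet)

c = Chroma()
assert isinstance(c.diffusion_model, ChromaNet)
assert c.concat_cond() == "shared inpaint/control concat logic"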

comfy/model_detection.py

Lines changed: 13 additions & 28 deletions

@@ -154,32 +154,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["guidance_embed"] = len(guidance_keys) > 0
        return dit_config

-    if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
-        dit_config = {}
-        dit_config["image_model"] = "chroma"
-        dit_config["depth"] = 48
-        dit_config["in_channels"] = 64
-        patch_size = 2
-        dit_config["patch_size"] = patch_size
-        in_key = "{}img_in.weight".format(key_prefix)
-        if in_key in state_dict_keys:
-            dit_config["in_channels"] = state_dict[in_key].shape[1]
-        dit_config["out_channels"] = 64
-        dit_config["context_in_dim"] = 4096
-        dit_config["hidden_size"] = 3072
-        dit_config["mlp_ratio"] = 4.0
-        dit_config["num_heads"] = 24
-        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
-        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
-        dit_config["axes_dim"] = [16, 56, 56]
-        dit_config["theta"] = 10000
-        dit_config["qkv_bias"] = True
-        dit_config["in_dim"] = 64
-        dit_config["out_dim"] = 3072
-        dit_config["hidden_dim"] = 5120
-        dit_config["n_layers"] = 5
-        return dit_config
-
     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
         dit_config = {}
         dit_config["image_model"] = "flux"

@@ -190,7 +164,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        if in_key in state_dict_keys:
            dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
        dit_config["out_channels"] = 16
-        dit_config["vec_in_dim"] = 768
+        vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
+        if vec_in_key in state_dict_keys:
+            dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
        dit_config["context_in_dim"] = 4096
        dit_config["hidden_size"] = 3072
        dit_config["mlp_ratio"] = 4.0

@@ -200,7 +176,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["axes_dim"] = [16, 56, 56]
        dit_config["theta"] = 10000
        dit_config["qkv_bias"] = True
-        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+        if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+            dit_config["image_model"] = "chroma"
+            dit_config["in_channels"] = 64
+            dit_config["out_channels"] = 64
+            dit_config["in_dim"] = 64
+            dit_config["out_dim"] = 3072
+            dit_config["hidden_dim"] = 5120
+            dit_config["n_layers"] = 5
+        else:
+            dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        return dit_config

    if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
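Model detection here is pure key-presence sniffing on the checkpoint's state dict. After this change, Chroma is no longer a separate top-level branch: it enters through the shared Flux branch (both families share the double_blocks/single_blocks layout) and is distinguished inside it by the distilled_guidance_layer weights that only Chroma has. A simplified, hedged sketch of the flow; the real detect_unet_config handles many more families and fills a full dit_config:

def detect_family(state_dict_keys, key_prefix=""):
    def has(suffix):
        return "{}{}".format(key_prefix, suffix) in state_dict_keys

    if has("double_blocks.0.img_attn.norm.key_norm.scale") and has("img_in.weight"):
        if has("distilled_guidance_layer.0.norms.0.scale") or has("distilled_guidance_layer.norms.0.scale"):
            return "chroma"   # Chroma-only distilled guidance layer
        return "flux"
    return "unknown"

keys = {"double_blocks.0.img_attn.norm.key_norm.scale", "img_in.weight",
        "distilled_guidance_layer.norms.0.scale"}
assert detect_family(keys) == "chroma"

The vec_in_dim change in the middle hunk follows the same philosophy: read the dimension from the vector_in.in_layer.weight shape when the key exists (Flux has a vector_in), and simply skip it when it does not (Chroma has none), instead of hardcoding 768.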

comfy/sd.py

Lines changed: 1 addition & 5 deletions

@@ -42,7 +42,6 @@
 import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
 import comfy.text_encoders.hidream
-import comfy.text_encoders.chroma

 import comfy.model_patcher
 import comfy.lora

@@ -820,7 +819,7 @@ class EmptyClass:
        elif clip_type == CLIPType.LTXV:
            clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
-        elif clip_type == CLIPType.PIXART:
+        elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
            clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
        elif clip_type == CLIPType.WAN:

@@ -831,9 +830,6 @@ class EmptyClass:
            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
                                                                        clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
-        elif clip_type == CLIPType.CHROMA:
-            clip_target.clip = comfy.text_encoders.chroma.chroma_te(**t5xxl_detect(clip_data))
-            clip_target.tokenizer = comfy.text_encoders.chroma.ChromaT5Tokenizer
        else: #CLIPType.MOCHI
            clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
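With the dedicated comfy/text_encoders/chroma.py module deleted, CLIPType.CHROMA is folded into the existing PixArt branch, so both resolve to the same T5-XXL encoder and tokenizer. A simplified sketch of that routing; the enum values and return tuples are illustrative, not the real comfy.sd code:

from enum import Enum

class CLIPType(Enum):
    PIXART = 1
    CHROMA = 2
    MOCHI = 3

def pick_text_encoder(clip_type):
    if clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
        return ("pixart_t5.pixart_te", "pixart_t5.PixArtTokenizer")
    return ("genmo.mochi_te", "genmo.MochiT5Tokenizer")  # the else branch

# Chroma and PixArt now resolve to the identical encoder/tokenizer pair.
assert pick_text_encoder(CLIPType.CHROMA) == pick_text_encoder(CLIPType.PIXART)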

comfy/supported_models.py

Lines changed: 1 addition & 2 deletions

@@ -17,7 +17,6 @@
 import comfy.text_encoders.cosmos
 import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
-import comfy.text_encoders.chroma

 from . import supported_models_base
 from . import latent_formats

@@ -1095,7 +1094,7 @@ def get_model(self, state_dict, prefix="", device=None):
    def clip_target(self, state_dict={}):
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
-        return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect))
+        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

 models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma]
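The checkpoint-loading path mirrors the sd.py change above: a ClipTarget bundles a tokenizer class with an encoder factory, and Chroma's config now hands back PixArt's pair instead of its own. A toy sketch of that pairing, with stand-in types rather than the real comfy.supported_models_base ones:

class ClipTarget:
    # tokenizer and clip travel together so loaders stay consistent
    def __init__(self, tokenizer, clip):
        self.tokenizer = tokenizer
        self.clip = clip

class PixArtTokenizer: pass
def pixart_te(**t5_detect): pass  # factory; kwargs come from t5_xxl_detect

target = ClipTarget(PixArtTokenizer, pixart_te)  # shared by PixArt and Chroma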

comfy/text_encoders/chroma.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

comfy_extras/nodes_optimalsteps.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ def loglinear_interp(t_steps, num_steps):

 NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
                 "Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
-                "Chroma": [0.9919999837875366, 0.9900000095367432, 0.9879999756813049, 0.9850000143051147, 0.9819999933242798, 0.9779999852180481, 0.9729999899864197, 0.9679999947547913, 0.9610000252723694, 0.953000009059906, 0.9430000185966492, 0.9309999942779541, 0.9169999957084656, 0.8999999761581421, 0.8809999823570251, 0.8579999804496765, 0.8320000171661377, 0.8019999861717224, 0.7689999938011169, 0.7310000061988831, 0.6899999976158142, 0.6460000276565552, 0.5989999771118164, 0.550000011920929, 0.5009999871253967, 0.45100000500679016, 0.4020000100135803, 0.35499998927116394, 0.3109999895095825, 0.27000001072883606, 0.23199999332427979, 0.19900000095367432, 0.16899999976158142, 0.14300000667572021, 0.11999999731779099, 0.10100000351667404, 0.08399999886751175, 0.07000000029802322, 0.057999998331069946, 0.04800000041723251, 0.0],
+                "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
 }

 class OptimalStepsScheduler:
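The rewritten Chroma entry keeps the same 41-point schedule rounded to three decimals (the old values were float32 artifacts such as 0.9919999837875366), with one behavioral nuance: it now terminates at 0.001 like the FLUX and Wan entries rather than at 0.0, which matters if the table is ever resampled in log-space, since log(0) is -inf. For context, a hedged sketch of log-linear resampling in the spirit of the loglinear_interp(t_steps, num_steps) helper defined at the top of this file; this is one plausible implementation, not necessarily the file's exact code:

import numpy as np

def loglinear_interp(t_steps, num_steps):
    # Resample a descending noise schedule by interpolating in log-space.
    xs = np.linspace(0.0, 1.0, len(t_steps))
    ys = np.log(np.asarray(t_steps)[::-1])   # reverse to ascending for np.interp
    new_xs = np.linspace(0.0, 1.0, num_steps)
    new_ys = np.interp(new_xs, xs, ys)
    return np.exp(new_ys)[::-1].tolist()     # back to descending order

chroma = [0.992, 0.9, 0.5, 0.1, 0.048, 0.001]  # truncated table for brevity
print(loglinear_interp(chroma, 4))             # e.g. a 4-step schedule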
