Skip to content

Commit ffdd53b

Browse files
Check state dict key to auto enable the index_timestep_zero ref method. (Comfy-Org#11362)
1 parent 65e2103 commit ffdd53b

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,9 @@ def __init__(
363363
for _ in range(num_layers)
364364
])
365365

366+
if self.default_ref_method == "index_timestep_zero":
367+
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
368+
366369
if final_layer:
367370
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
368371
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)

comfy/model_detection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
259259
dit_config["nerf_tile_size"] = 512
260260
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
261261
dit_config["nerf_embedder_dtype"] = torch.float32
262-
if "__x0__" in state_dict_keys: # x0 pred
262+
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
263263
dit_config["use_x0"] = True
264264
else:
265265
dit_config["use_x0"] = False
@@ -618,6 +618,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
618618
dit_config["image_model"] = "qwen_image"
619619
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
620620
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
621+
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
622+
dit_config["default_ref_method"] = "index_timestep_zero"
621623
return dit_config
622624

623625
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5

0 commit comments

Comments
 (0)