
Commit dc9822b

Add working Qwen 2512 ControlNet (Fun ControlNet) support (Comfy-Org#12359)
1 parent 712efb4 commit dc9822b

File tree: 2 files changed, +263 −0 lines changed

comfy/controlnet.py

Lines changed: 73 additions & 0 deletions
@@ -297,6 +297,30 @@ def cleanup(self):
         self.model_sampling_current = None
         super().cleanup()
 
+
+class QwenFunControlNet(ControlNet):
+    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
+        # Fun checkpoints are more sensitive to high strengths in the generic
+        # ControlNet merge path. Use a soft response curve so strength=1.0 stays
+        # unchanged while >1 grows more gently.
+        original_strength = self.strength
+        self.strength = math.sqrt(max(self.strength, 0.0))
+        try:
+            return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
+        finally:
+            self.strength = original_strength
+
+    def pre_run(self, model, percent_to_timestep_function):
+        super().pre_run(model, percent_to_timestep_function)
+        self.set_extra_arg("base_model", model.diffusion_model)
+
+    def copy(self):
+        c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
 class ControlLoraOps:
     class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(self, in_features: int, out_features: int, bias: bool = True,
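A quick numeric check of the soft response curve above (a standalone sketch, not part of the commit): square-rooting leaves strength=1.0 fixed and compresses higher values; note that it also lifts sub-1.0 values slightly toward 1.

import math

for s in (0.0, 0.5, 1.0, 2.0, 4.0):
    print(s, "->", round(math.sqrt(max(s, 0.0)), 3))
# 0.0 -> 0.0, 0.5 -> 0.707, 1.0 -> 1.0, 2.0 -> 1.414, 4.0 -> 2.0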
@@ -606,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
 
+
+def load_controlnet_qwen_fun(sd, model_options={}):
+    load_device = comfy.model_management.get_torch_device()
+    weight_dtype = comfy.utils.weight_dtype(sd)
+    unet_dtype = model_options.get("dtype", weight_dtype)
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    in_features = sd["control_img_in.weight"].shape[1]
+    inner_dim = sd["control_img_in.weight"].shape[0]
+
+    block_weight = sd["control_blocks.0.attn.to_q.weight"]
+    attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
+    num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
+
+    model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
+        control_in_features=in_features,
+        inner_dim=inner_dim,
+        num_attention_heads=num_attention_heads,
+        attention_head_dim=attention_head_dim,
+        num_control_blocks=5,
+        main_model_double=60,
+        injection_layers=(0, 12, 24, 36, 48),
+        operations=operations,
+        device=comfy.model_management.unet_offload_device(),
+        dtype=unet_dtype,
+    )
+    model = controlnet_load_state_dict(model, sd)
+
+    latent_format = comfy.latent_formats.Wan21()
+    control = QwenFunControlNet(
+        model,
+        compression_ratio=1,
+        latent_format=latent_format,
+        # Fun checkpoints already expect their own 33-channel context handling.
+        # Enabling generic concat_mask injects an extra mask channel at apply-time
+        # and breaks the intended fallback packing path.
+        concat_mask=False,
+        load_device=load_device,
+        manual_cast_dtype=manual_cast_dtype,
+        extra_conds=[],
+    )
+    return control
+
 def convert_mistoline(sd):
     return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
 
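The loader derives everything from checkpoint shapes: norm_q's weight length is treated as the head dimension (it is a per-head query norm), and dividing the to_q rows by it gives the head count. A standalone sketch with the shapes a stock Fun checkpoint would carry; the concrete values (inner_dim=3072, head_dim=128, 132 packed features) are assumptions matching the defaults above.

import torch

sd = {
    "control_img_in.weight": torch.zeros(3072, 132),           # inner_dim x packed control features
    "control_blocks.0.attn.to_q.weight": torch.zeros(3072, 3072),
    "control_blocks.0.attn.norm_q.weight": torch.zeros(128),   # per-head norm -> head_dim
}
in_features = sd["control_img_in.weight"].shape[1]             # 132
inner_dim = sd["control_img_in.weight"].shape[0]               # 3072
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]  # 128
num_attention_heads = max(1, sd["control_blocks.0.attn.to_q.weight"].shape[0] // max(1, attention_head_dim))
print(num_attention_heads)  # 24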

@@ -683,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
         return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
     elif "controlnet_x_embedder.weight" in controlnet_data:
         return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
+    elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
+        return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
 
     elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
         return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
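Nothing new is required from callers: any path that already feeds a state dict through load_controlnet_state_dict picks up the new branch. A hedged sketch of that flow (the file path is hypothetical):

import comfy.utils
import comfy.controlnet

sd = comfy.utils.load_torch_file("models/controlnet/qwen_image_fun_control.safetensors")  # hypothetical path
control = comfy.controlnet.load_controlnet_state_dict(sd)
# A Fun checkpoint carries both "control_blocks.0.after_proj.weight" and
# "control_img_in.weight", so this returns a QwenFunControlNet instance.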

comfy/ldm/qwen_image/controlnet.py

Lines changed: 190 additions & 0 deletions
@@ -2,6 +2,196 @@
 import math
 
 from .model import QwenImageTransformer2DModel
+from .model import QwenImageTransformerBlock
+
+
+class QwenImageFunControlBlock(QwenImageTransformerBlock):
+    def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
+        super().__init__(
+            dim=dim,
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=attention_head_dim,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+        self.has_before_proj = has_before_proj
+        if has_before_proj:
+            self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+        self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+
+
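Each Fun control block is a stock QwenImageTransformerBlock plus two linear taps: before_proj (first block only) folds the projected hint into the image stream, and after_proj turns each block's output into the residual later added to a base-model layer. A minimal standalone sketch of that data flow, with plain nn.Linear standing in for the operations wrappers and made-up tensor sizes:

import torch

dim = 8
before_proj = torch.nn.Linear(dim, dim)  # first block only
after_proj = torch.nn.Linear(dim, dim)   # every block

hint_tokens = torch.randn(1, 4, dim)     # control_img_in output
image_tokens = torch.randn(1, 4, dim)    # packed image latents after img_in

c_in = before_proj(hint_tokens) + image_tokens  # block 0 input
c_out = c_in                                    # stand-in for the transformer block itself
c_skip = after_proj(c_out)                      # residual handed to the base model
print(c_skip.shape)                             # torch.Size([1, 4, 8])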
+class QwenImageFunControlNetModel(torch.nn.Module):
+    def __init__(
+        self,
+        control_in_features=132,
+        inner_dim=3072,
+        num_attention_heads=24,
+        attention_head_dim=128,
+        num_control_blocks=5,
+        main_model_double=60,
+        injection_layers=(0, 12, 24, 36, 48),
+        dtype=None,
+        device=None,
+        operations=None,
+    ):
+        super().__init__()
+        self.dtype = dtype
+        self.main_model_double = main_model_double
+        self.injection_layers = tuple(injection_layers)
+        # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
+        # to the reference Gen2/VideoX implementation around strength=1.
+        self.hint_scale = 1.0
+        self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
+
+        self.control_blocks = torch.nn.ModuleList([])
+        for i in range(num_control_blocks):
+            self.control_blocks.append(
+                QwenImageFunControlBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    has_before_proj=(i == 0),
+                    dtype=dtype,
+                    device=device,
+                    operations=operations,
+                )
+            )
+
+    def _process_hint_tokens(self, hint):
+        if hint is None:
+            return None
+        if hint.ndim == 4:
+            hint = hint.unsqueeze(2)
+
+        # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
+        # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
+        # Default behavior (no inpaint input in stock Apply ControlNet) should use
+        # zeros for mask/inpaint branches, matching VideoX fallback semantics.
+        expected_c = self.control_img_in.weight.shape[1] // 4
+        if hint.shape[1] == 16 and expected_c == 33:
+            zeros_mask = torch.zeros_like(hint[:, :1])
+            zeros_inpaint = torch.zeros_like(hint)
+            hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
+
+        bs, c, t, h, w = hint.shape
+        hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(
+            orig_shape[0],
+            orig_shape[1],
+            orig_shape[-3],
+            orig_shape[-2] // 2,
+            2,
+            orig_shape[-1] // 2,
+            2,
+        )
+        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+        hidden_states = hidden_states.reshape(
+            bs,
+            t * ((h + 1) // 2) * ((w + 1) // 2),
+            c * 4,
+        )
+
+        expected_in = self.control_img_in.weight.shape[1]
+        cur_in = hidden_states.shape[-1]
+        if cur_in < expected_in:
+            pad = torch.zeros(
+                (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            hidden_states = torch.cat([hidden_states, pad], dim=-1)
+        elif cur_in > expected_in:
+            hidden_states = hidden_states[:, :, :expected_in]
+
+        return hidden_states
+
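A shape walk-through of _process_hint_tokens (standalone sketch; the 1024x1024 input size is an assumption): with the Wan 2.1 VAE's 8x spatial compression a 1024x1024 hint becomes a 16-channel 128x128 latent, the zero mask/inpaint branches bring it to 33 channels, and 2x2 packing yields 64*64 = 4096 tokens of 33*4 = 132 features, exactly the control_img_in input width.

import torch

hint = torch.zeros(1, 16, 1, 128, 128)          # (bs, c, t, h, w) after unsqueeze(2)
mask = torch.zeros_like(hint[:, :1])            # +1 channel
inpaint = torch.zeros_like(hint)                # +16 channels
hint = torch.cat([hint, mask, inpaint], dim=1)  # (1, 33, 1, 128, 128)

tokens = (
    hint.view(1, 33, 1, 64, 2, 64, 2)
    .permute(0, 2, 3, 5, 1, 4, 6)               # -> (bs, t, h/2, w/2, c, 2, 2)
    .reshape(1, 1 * 64 * 64, 33 * 4)
)
print(tokens.shape)  # torch.Size([1, 4096, 132])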
+    def forward(
+        self,
+        x,
+        timesteps,
+        context,
+        attention_mask=None,
+        guidance: torch.Tensor = None,
+        hint=None,
+        transformer_options={},
+        base_model=None,
+        **kwargs,
+    ):
+        if base_model is None:
+            raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
+
+        encoder_hidden_states_mask = attention_mask
+        # Keep attention mask disabled inside Fun control blocks to mirror
+        # VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
+        encoder_hidden_states_mask = None
+
+        hidden_states, img_ids, _ = base_model.process_img(x)
+        hint_tokens = self._process_hint_tokens(hint)
+        if hint_tokens is None:
+            raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
+
+        if hint_tokens.shape[1] != hidden_states.shape[1]:
+            max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
+            hint_tokens = hint_tokens[:, :max_tokens]
+            hidden_states = hidden_states[:, :max_tokens]
+            img_ids = img_ids[:, :max_tokens]
+
+        txt_start = round(
+            max(
+                ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+                ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+            )
+        )
+        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
+
+        hidden_states = base_model.img_in(hidden_states)
+        encoder_hidden_states = base_model.txt_norm(context)
+        encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
+
+        if guidance is not None:
+            guidance = guidance * 1000
+
+        temb = (
+            base_model.time_text_embed(timesteps, hidden_states)
+            if guidance is None
+            else base_model.time_text_embed(timesteps, guidance, hidden_states)
+        )
+
+        c = self.control_img_in(hint_tokens)
+
+        for i, block in enumerate(self.control_blocks):
+            if i == 0:
+                c_in = block.before_proj(c) + hidden_states
+                all_c = []
+            else:
+                all_c = list(torch.unbind(c, dim=0))
+                c_in = all_c.pop(-1)
+
+            encoder_hidden_states, c_out = block(
+                hidden_states=c_in,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                transformer_options=transformer_options,
+            )
+
+            c_skip = block.after_proj(c_out) * self.hint_scale
+            all_c += [c_skip, c_out]
+            c = torch.stack(all_c, dim=0)
+
+        hints = torch.unbind(c, dim=0)[:-1]
+
+        controlnet_block_samples = [None] * self.main_model_double
+        for local_idx, base_idx in enumerate(self.injection_layers):
+            if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
+                controlnet_block_samples[base_idx] = hints[local_idx]
+
+        return {"input": controlnet_block_samples}
 
 
 class QwenImageControlNetModel(QwenImageTransformer2DModel):

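The tail of forward() spreads the five after_proj skips (collected in block order) evenly across the base model's 60 double blocks. A standalone sketch of that mapping, with placeholder strings standing in for the hint tensors:

injection_layers = (0, 12, 24, 36, 48)
main_model_double = 60
hints = [f"skip{i}" for i in range(5)]  # after_proj outputs, in block order

controlnet_block_samples = [None] * main_model_double
for local_idx, base_idx in enumerate(injection_layers):
    if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
        controlnet_block_samples[base_idx] = hints[local_idx]

print([i for i, h in enumerate(controlnet_block_samples) if h is not None])
# [0, 12, 24, 36, 48] -> every 12th layer gets a control residual; the rest stay None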