Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions comfy/ldm/chroma/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def forward_orig(
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})

# running on sequences img
Expand Down Expand Up @@ -228,6 +229,7 @@ def block_wrap(args):

transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:
Expand Down
18 changes: 18 additions & 0 deletions comfy/ldm/flux/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()

# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
Expand Down Expand Up @@ -224,6 +227,12 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v

if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)

txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

# calculate the img blocks
Expand Down Expand Up @@ -303,6 +312,9 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
else:
mod = vec

transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()

qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)

q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
Expand All @@ -312,6 +324,12 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v

if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)

# compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
Expand Down
2 changes: 2 additions & 0 deletions comfy/ldm/flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def forward_orig(
attn_mask: Tensor = None,
) -> Tensor:

transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
Expand Down Expand Up @@ -231,6 +232,7 @@ def block_wrap(args):

transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
Expand Down
2 changes: 2 additions & 0 deletions comfy/ldm/hunyuan_video/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def forward_orig(
control=None,
transformer_options={},
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})

initial_shape = list(img.shape)
Expand Down Expand Up @@ -416,6 +417,7 @@ def block_wrap(args):

transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
Expand Down
18 changes: 10 additions & 8 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,7 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)

context = c_crossattn
dtype = self.get_dtype()

if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
dtype = self.get_dtype_inference()

xc = xc.to(dtype)
device = xc.device
Expand Down Expand Up @@ -218,6 +215,13 @@ def process_timestep(self, timestep, **kwargs):
def get_dtype(self):
    """Return the ``dtype`` attribute of the wrapped ``diffusion_model``."""
    return self.diffusion_model.dtype

def get_dtype_inference(self):
    """Return the dtype to use when running the model.

    ``manual_cast_dtype`` takes precedence whenever it is set (not ``None``);
    otherwise the value reported by ``get_dtype()`` is returned.
    """
    model_dtype = self.get_dtype()
    cast_dtype = self.manual_cast_dtype
    return model_dtype if cast_dtype is None else cast_dtype

def encode_adm(self, **kwargs):
    """Return ``None``; the keyword arguments are accepted but ignored here.

    NOTE(review): presumably a hook that subclasses override -- confirm.
    """
    return None

Expand Down Expand Up @@ -372,9 +376,7 @@ def memory_required(self, input_shape, cond_shapes={}):
input_shapes += shape

if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
dtype = self.get_dtype_inference()
#TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
Expand Down Expand Up @@ -1165,7 +1167,7 @@ def extra_conds(self, **kwargs):
t5xxl_ids = t5xxl_ids.unsqueeze(0)

if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
Expand Down
5 changes: 4 additions & 1 deletion comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,16 @@ def clone_has_same_weights(self, clone: 'ModelPatcher'):
def memory_required(self, input_shape):
    """Delegate to ``self.model.memory_required`` for the given ``input_shape``."""
    return self.model.memory_required(input_shape=input_shape)

def disable_model_cfg1_optimization(self):
    """Set the ``"disable_cfg1_optimization"`` flag in ``self.model_options`` to ``True``."""
    self.model_options["disable_cfg1_optimization"] = True

def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
    """Register ``sampler_cfg_function`` under ``model_options["sampler_cfg_function"]``.

    Args:
        sampler_cfg_function: Callable to install. A callable taking exactly
            three parameters is treated as the old ``(cond, uncond,
            cond_scale)`` style and is wrapped so that it can be invoked with
            a single ``args`` dict. Any other callable is stored unchanged.
        disable_cfg1_optimization: When true, also turn on the
            ``"disable_cfg1_optimization"`` flag via
            ``disable_model_cfg1_optimization()``.
    """
    if len(inspect.signature(sampler_cfg_function).parameters) == 3:
        # Old way: adapt the positional signature to the dict-based call.
        self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])
    else:
        self.model_options["sampler_cfg_function"] = sampler_cfg_function
    if disable_cfg1_optimization:
        # Single source of truth for the flag; no direct assignment here.
        self.disable_model_cfg1_optimization()

def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
    """Replace ``self.model_options`` with the result of ``set_model_options_post_cfg_function``.

    The current options, ``post_cfg_function`` and ``disable_cfg1_optimization``
    are forwarded to that helper unchanged.
    """
    self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
Expand Down
99 changes: 99 additions & 0 deletions comfy_extras/nodes_nag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override


class NAGuidance(io.ComfyNode):
    """Node that patches a model with Normalized Attention Guidance (NAG).

    The node registers an ``attn1`` output patch on a clone of the input
    model. Whenever a batch contains both a positive and a negative
    condition, the patch extrapolates the positive attention output away from
    the negative one, limits how much the L1 norm may grow, and blends the
    result back with the positive output.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: identifier, display text, inputs and outputs."""
        return io.Schema(
            node_id="NAGuidance",
            display_name="Normalized Attention Guidance",
            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
            category="",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to apply NAG to."),
                io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
                io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
                io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01, tooltip="Maximum allowed ratio between the L1 norm of the guided attention and the L1 norm of the positive attention. Larger ratios are scaled back down to this value."),
                # TODO: optional step range for NAG, currently disabled.
                # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
                # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
            ],
            outputs=[
                io.Model.Output(tooltip="The patched model with NAG enabled."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
        """Clone ``model``, attach the NAG attention patch and return the clone.

        Args:
            model: Model to patch; it is cloned and the clone is patched.
            nag_scale: Extrapolation strength away from the negative output.
            nag_alpha: Blend weight of the guided output versus the positive one.
            nag_tau: Upper bound on the guided/positive L1 norm ratio.

        Returns:
            ``io.NodeOutput`` wrapping the patched clone.
        """
        m = model.clone()

        # TODO: disabled together with the start_percent/end_percent inputs.
        # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
        # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)

        def nag_attention_output_patch(out, extra_options):
            """Apply NAG to an attention output; ``out`` is modified in place."""
            cond_or_uncond = extra_options.get("cond_or_uncond", None)
            if cond_or_uncond is None:
                return out

            # Guidance needs both a positive (0) and a negative (1) entry.
            if not (1 in cond_or_uncond and 0 in cond_or_uncond):
                return out

            # sigma = extra_options.get("sigmas", None)
            # if sigma is not None and len(sigma) > 0:
            #     sigma = sigma[0].item()
            #     if sigma > sigma_start or sigma < sigma_end:
            #         return out

            img_slice = extra_options.get("img_slice", None)

            if img_slice is not None:
                # Only the image tokens are guided; keep a handle on the full
                # tensor so that the result can be written back into it.
                orig_out = out
                out = out[:, img_slice[0]:img_slice[1]]  # only apply on img part

            # The batch is laid out as equally sized chunks, one per entry of
            # cond_or_uncond.
            batch_size = out.shape[0]
            chunk_size = batch_size // len(cond_or_uncond)

            ind_neg = cond_or_uncond.index(1)
            ind_pos = cond_or_uncond.index(0)
            pos_rows = slice(chunk_size * ind_pos, chunk_size * (ind_pos + 1))
            neg_rows = slice(chunk_size * ind_neg, chunk_size * (ind_neg + 1))
            z_pos = out[pos_rows]
            z_neg = out[neg_rows]

            # Extrapolate from the negative output through the positive one.
            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)

            # Clamp the norms away from zero so the ratios below stay finite.
            eps = 1e-6
            norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
            norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)

            # Shrink the guided output wherever its norm grew by more than
            # nag_tau relative to the positive output.
            ratio = norm_guided / norm_pos
            scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio

            guided_normalized = guided * scale_factor

            z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)

            if img_slice is not None:
                orig_out[neg_rows, img_slice[0]:img_slice[1]] = z_final
                orig_out[pos_rows, img_slice[0]:img_slice[1]] = z_final
                return orig_out
            else:
                # NOTE(review): unlike the img_slice branch, only the positive
                # rows are overwritten here -- confirm this is intended.
                out[pos_rows] = z_final
                return out

        m.set_model_attn1_output_patch(nag_attention_output_patch)
        # NOTE(review): presumably required so the negative condition is still
        # evaluated when cfg is 1.0 -- confirm against the sampler.
        m.disable_model_cfg1_optimization()

        return io.NodeOutput(m)


class NagExtension(ComfyExtension):
    """Extension that exposes the NAG node defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        return [
            NAGuidance,
        ]


async def comfy_entrypoint() -> NagExtension:
    """Create and return a new ``NagExtension`` instance."""
    return NagExtension()
2 changes: 1 addition & 1 deletion comfyui_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.13.0"
__version__ = "0.14.0"
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2437,6 +2437,7 @@ async def init_builtin_extra_nodes():
"nodes_color.py",
"nodes_toolkit.py",
"nodes_replacements.py",
"nodes_nag.py",
]

import_failed = []
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.13.0"
version = "0.14.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
Expand Down
Loading