1 change: 1 addition & 0 deletions comfy/audio_encoders/audio_encoders.py
@@ -41,6 +41,7 @@ def encode_audio(self, audio, sample_rate):
outputs = {}
outputs["encoded_audio"] = out
outputs["encoded_audio_all_layers"] = all_layers
outputs["audio_samples"] = audio.shape[2]
return outputs


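The new audio_samples entry exposes the raw waveform length alongside the encoder features. A minimal sketch of how a downstream consumer might use it; the 16 kHz sample rate and 25 fps target are assumptions for illustration, not values taken from this PR:

# Hypothetical consumer of the encoder output dict (values are made up).
sample_rate = 16000                    # assumed encoder sample rate
fps = 25                               # assumed video frame rate
outputs = {"audio_samples": 64000}     # e.g. 4 seconds of audio
num_frames = outputs["audio_samples"] * fps // sample_rate
print(num_frames)                      # 100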
11 changes: 5 additions & 6 deletions comfy/ldm/flux/math.py
@@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)

def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
return x_out.reshape(*x.shape).type_as(x)

def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
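apply_rope is now expressed through a single-tensor helper, apply_rope1, so callers can rotate query and key separately (the Wan self-attention change below relies on this to keep only one large intermediate alive at a time). A self-contained sketch with made-up shapes showing the refactored form is a drop-in for the old pairwise call:

import torch

def apply_rope1(x, freqs_cis):
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
    x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
    return x_out.reshape(*x.shape).type_as(x)

# Illustrative shapes: batch=1, heads=2, seq=8, head_dim=16.
xq = torch.randn(1, 2, 8, 16)
xk = torch.randn(1, 2, 8, 16)
freqs_cis = torch.randn(1, 1, 8, 8, 2, 2)    # broadcasts over batch and heads

# The old apply_rope(xq, xk, freqs_cis) is now just two independent calls:
xq_out, xk_out = apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
print(xq_out.shape, xk_out.shape)            # both keep their input shape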
279 changes: 266 additions & 13 deletions comfy/ldm/wan/model.py
@@ -8,7 +8,7 @@

from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.flux.math import apply_rope1
import comfy.ldm.common_dit
import comfy.model_management
import comfy.patcher_extension
@@ -34,7 +34,9 @@ def __init__(self,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6, operation_settings={}):
eps=1e-6,
kv_dim=None,
operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
@@ -43,11 +45,13 @@ def __init__(self,
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
if kv_dim is None:
kv_dim = dim

# layers
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
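The new kv_dim parameter lets subclasses reuse this attention module for cross-attention onto features with a different channel count; the audio tokens introduced below are 1536-dim while queries still come from the dim-wide video stream. A hypothetical sketch of the asymmetric projections (plain torch.nn stands in for the operations object; sizes are illustrative):

import torch
import torch.nn as nn

dim, kv_dim = 2048, 1536
to_q = nn.Linear(dim, dim)        # queries from the video stream
to_k = nn.Linear(kv_dim, dim)     # what kv_dim changes: keys project 1536 -> 2048
to_v = nn.Linear(kv_dim, dim)

x = torch.randn(1, 880, dim)      # video tokens
ctx = torch.randn(1, 32, kv_dim)  # audio tokens
q, k, v = to_q(x), to_k(ctx), to_v(ctx)
print(q.shape, k.shape, v.shape)  # all land in the same dim-wide head space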
@@ -60,20 +64,24 @@ def forward(self, x, freqs, transformer_options={}):
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

# query, key, value function
def qkv_fn(x):
def qkv_fn_q(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
return apply_rope1(q, freqs)

def qkv_fn_k(x):
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n * d)
return q, k, v
return apply_rope1(k, freqs)

q, k, v = qkv_fn(x)
q, k = apply_rope(q, k, freqs)
#These two are VRAM hogs, so we want to do all of q computation and
#have pytorch garbage collect the intermediates on the sub function
#return before we touch k
q = qkv_fn_q(x)
k = qkv_fn_k(x)

x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v,
self.v(x).view(b, s, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
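The split into qkv_fn_q / qkv_fn_k is purely a peak-memory optimization: because the pre-RoPE projection only exists inside the helper's scope, it becomes collectable when the helper returns, so q's intermediate is gone before k's is allocated. A toy illustration of the scoping pattern (shapes are made up; the multiply stands in for apply_rope1):

import torch

def project_and_rotate(weight, x):
    big = x @ weight              # large intermediate, local to this call
    return big * 0.5              # stand-in for apply_rope1(big, freqs)

x = torch.randn(4096, 2048)
w_q = torch.randn(2048, 2048)
w_k = torch.randn(2048, 2048)

q = project_and_rotate(w_q, x)    # the un-rotated q is already freed here
k = project_and_rotate(w_k, x)    # so the peak holds one extra tensor, not two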
@@ -398,6 +406,7 @@ def __init__(self,
eps=1e-6,
flf_pos_embed_token_number=None,
in_dim_ref_conv=None,
wan_attn_block_class=WanAttentionBlock,
image_model=None,
device=None,
dtype=None,
@@ -475,8 +484,8 @@ def __init__(self,
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
])

@@ -1321,3 +1330,247 @@ def block_wrap(args):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x


class WanT2VCrossAttentionGather(WanSelfAttention):

def forward(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C] - video tokens
context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
"""
b, n, d = x.size(0), self.num_heads, self.head_dim

q = self.norm_q(self.q(x))
k = self.norm_k(self.k(context))
v = self.v(context)

# Handle audio temporal structure (16 tokens per frame)
k = k.reshape(-1, 16, n, d).transpose(1, 2)
v = v.reshape(-1, 16, n, d).transpose(1, 2)

# Handle video spatial structure
q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)

x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)

x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
x = self.o(x)
return x
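The "gather" here means each latent frame attends only to its own audio window: keys/values are folded so that every group of 16 audio tokens becomes its own attention batch, and the video tokens are regrouped per frame to match. A shape walkthrough with illustrative sizes (2 frames, 880 video tokens per frame, 12 heads of 128 channels, so dim = 1536); scaled_dot_product_attention stands in for optimized_attention:

import torch
import torch.nn.functional as F

b, frames, tok_per_frame, n, d = 1, 2, 880, 12, 128

q = torch.randn(b, frames * tok_per_frame, n * d)   # projected video tokens
k = torch.randn(b, frames * 16, n * d)              # 16 audio tokens per frame
v = torch.randn(b, frames * 16, n * d)

k = k.reshape(-1, 16, n, d).transpose(1, 2)         # [b*frames, n, 16, d]
v = v.reshape(-1, 16, n, d).transpose(1, 2)         # [b*frames, n, 16, d]
q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2) # [b*frames, n, tok_per_frame, d]

out = F.scaled_dot_product_attention(q, k, v)       # per-frame attention
out = out.transpose(1, 2).reshape(b, -1, n * d)     # back to [b, frames*tok, dim]
print(out.shape)                                    # torch.Size([1, 1760, 1536])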


class AudioCrossAttentionWrapper(nn.Module):
def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
super().__init__()

self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm, kv_dim, eps, operation_settings=operation_settings)
self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

def forward(self, x, audio, transformer_options={}):
x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
return x


class WanAttentionBlockAudio(WanAttentionBlock):

def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)

def forward(
self,
x,
e,
freqs,
context,
context_img_len=257,
audio=None,
transformer_options={},
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32

if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# assert e[0].dtype == torch.float32

# self-attention
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)

x = torch.addcmul(x, y, repeat_e(e[2], x))

# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if audio is not None:
x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x

class DummyAdapterLayer(nn.Module):
def __init__(self, layer):
super().__init__()
self.layer = layer

def forward(self, *args, **kwargs):
return self.layer(*args, **kwargs)
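DummyAdapterLayer is a pass-through wrapper; its apparent purpose (an inference, not stated in this PR) is to add a ".layer." level to the parameter names so they match the HuMo checkpoint keys, e.g. the "audio_proj.audio_proj_glob_1.layer.bias" key that model_detection.py checks for below. A quick sketch of the effect on state-dict names:

import torch.nn as nn

class Wrapper(nn.Module):              # same idea as DummyAdapterLayer
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, *args, **kwargs):
        return self.layer(*args, **kwargs)

proj = Wrapper(nn.Linear(8, 4))
print([name for name, _ in proj.named_parameters()])   # ['layer.weight', 'layer.bias']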


class AudioProjModel(nn.Module):
def __init__(
self,
seq_len=5,
blocks=13, # add a new parameter blocks
channels=768, # add a new parameter channels
intermediate_dim=512,
output_dim=1536,
context_tokens=16,
device=None,
dtype=None,
operations=None,
):
super().__init__()

self.seq_len = seq_len
self.blocks = blocks
self.channels = channels
self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels.
self.intermediate_dim = intermediate_dim
self.context_tokens = context_tokens
self.output_dim = output_dim

# define multiple linear layers
self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))

self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))

def forward(self, audio_embeds):
video_length = audio_embeds.shape[1]
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
batch_size, window_size, blocks, channels = audio_embeds.shape
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)

audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))

context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)

context_tokens = self.audio_proj_glob_norm(context_tokens)
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)

return context_tokens
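AudioProjModel compresses one window of audio-encoder features per video frame (seq_len frames of temporal context x blocks encoder layers x channels) into context_tokens tokens of output_dim channels. A shape trace using the HuMo settings this PR passes in (seq_len=8, blocks=5, channels=1280, 16 tokens of 1536 channels); the batch size and frame count are made up, and a random tensor stands in for the three Linear layers plus LayerNorm:

import torch
from einops import rearrange

bz, f, w, blk, c = 1, 4, 8, 5, 1280
audio_embeds = torch.randn(bz, f, w, blk, c)              # [B, frames, window, blocks, channels]

x = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") # [4, 8, 5, 1280]
x = x.reshape(bz * f, w * blk * c)                        # [4, 51200] -> MLP input
tokens = torch.randn(bz * f, 16, 1536)                    # stand-in for the projection output
tokens = rearrange(tokens, "(bz f) m c -> bz f m c", f=f)
print(tokens.shape)                                       # torch.Size([1, 4, 16, 1536])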


class HumoWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""

def __init__(self,
model_type='humo',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
audio_token_num=16,
device=None,
dtype=None,
operations=None,
):

super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)

self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)

def forward_orig(
self,
x,
t,
context,
freqs=None,
audio_embed=None,
reference_latent=None,
transformer_options={},
**kwargs,
):
bs, _, time, height, width = x.shape

# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)

# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))

if reference_latent is not None:
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
ref = ref.flatten(2).transpose(1, 2)
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
x = torch.cat([x, ref], dim=1)
freqs = torch.cat([freqs, freqs_ref], dim=1)
del ref, freqs_ref

# context
context = self.text_embedding(context)
context_img_len = None

if audio_embed is not None:
audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
else:
audio = None

patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)

# head
x = self.head(x, e)

# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
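In forward_orig the projected audio [B, frames, 16, 1536] is flattened frame-major into [B, frames*16, 1536] before being handed to every block's audio cross-attention, which is exactly the layout WanT2VCrossAttentionGather's reshape(-1, 16, n, d) expects. A trace of that permute/flatten/transpose chain with illustrative sizes:

import torch

proj = torch.randn(1, 4, 16, 1536)        # output of self.audio_proj: [B, frames, tokens, C]

audio = proj.permute(0, 3, 1, 2)          # [1, 1536, 4, 16]
audio = audio.flatten(2)                  # [1, 1536, 64]  (frame-major ordering)
audio = audio.transpose(1, 2)             # [1, 64, 1536]  = [B, frames*16, C]
print(audio.shape)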
17 changes: 17 additions & 0 deletions comfy/model_base.py
@@ -1213,6 +1213,23 @@ def extra_conds(self, **kwargs):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out

class WAN21_HuMo(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel)
self.image_to_video = image_to_video

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)

audio_embed = kwargs.get("audio_embed", None)
if audio_embed is not None:
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)

reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
return out

class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
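extra_conds is where the sampler-side kwargs become model inputs: audio_embed is passed through unchanged and only the last entry of reference_latents is used, after process_latent_in. A hypothetical illustration of the expected kwargs (shapes are made up; the real tensors come from the HuMo conditioning nodes, not from this file):

import torch

kwargs = {
    # [B, frames, window, blocks, channels] -- the layout AudioProjModel expects
    "audio_embed": torch.randn(1, 77, 8, 5, 1280),
    # list of reference latents; only reference_latents[-1] is used
    "reference_latents": [torch.randn(1, 16, 1, 60, 104)],
}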
2 changes: 2 additions & 0 deletions comfy/model_detection.py
@@ -402,6 +402,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "camera_2.2"
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "s2v"
elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "humo"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"