
Commit a3904d7

[Tencent Hunyuan Team] Add HunyuanDiT-v1.2 Support (#8747)

Authored by gnobitab, with co-authors xingchaoliu and yiyixuxu.

* add v1.2 support

---------

Co-authored-by: xingchaoliu <[email protected]>
Co-authored-by: yiyixuxu <[email protected]>

1 parent 7bfc1ee · commit a3904d7

File tree: 2 files changed, +32 −11 lines

src/diffusers/models/embeddings.py (28 additions, 11 deletions)
```diff
@@ -717,7 +717,14 @@ def forward(self, x):
 
 
 class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+    def __init__(
+        self,
+        embedding_dim,
+        pooled_projection_dim=1024,
+        seq_len=256,
+        cross_attention_dim=2048,
+        use_style_cond_and_image_meta_size=True,
+    ):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -726,9 +733,15 @@ def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
         self.pooler = HunyuanDiTAttentionPool(
             seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
         )
+
         # Here we use a default learned embedder layer for future extension.
-        self.style_embedder = nn.Embedding(1, embedding_dim)
-        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+        if use_style_cond_and_image_meta_size:
+            self.style_embedder = nn.Embedding(1, embedding_dim)
+            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        else:
+            extra_in_dim = pooled_projection_dim
+
         self.extra_embedder = PixArtAlphaTextProjection(
             in_features=extra_in_dim,
             hidden_size=embedding_dim * 4,
@@ -743,16 +756,20 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
         # extra condition1: text
         pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
 
-        # extra condition2: image meta size embdding
-        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
-        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
-        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+        if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embdding
+            image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
+            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
 
-        # extra condition3: style embedding
-        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+            # extra condition3: style embedding
+            style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+
+            # Concatenate all extra vectors
+            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        else:
+            extra_cond = torch.cat([pooled_projections], dim=1)
 
-        # Concatenate all extra vectors
-        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
         conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
 
         return conditioning
```
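For context (not part of the commit), the sketch below exercises both branches of the new flag: with `use_style_cond_and_image_meta_size=True` (v1.0/v1.1) the extra condition concatenates pooled text, image-meta-size, and style embeddings; with `False` (v1.2) only the pooled text projection feeds `extra_embedder`. It assumes a diffusers version containing this commit; the width of 1408 and the batch shapes are illustrative choices, not values fixed by the diff.

```python
# Minimal sketch (assumptions noted above), not part of the commit.
import torch
from diffusers.models.embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding

dim = 1408  # illustrative hidden width

# v1.0/v1.1 path: extra_cond = pooled text (1024) + image meta size (6 * 256) + style (dim)
emb_v11 = HunyuanCombinedTimestepTextSizeStyleEmbedding(dim, use_style_cond_and_image_meta_size=True)

# v1.2 path: extra_cond is just the pooled text projection (1024)
emb_v12 = HunyuanCombinedTimestepTextSizeStyleEmbedding(dim, use_style_cond_and_image_meta_size=False)

timestep = torch.randint(0, 1000, (2,))
text = torch.randn(2, 256, 2048)           # (N, seq_len=256, cross_attention_dim=2048)
meta = torch.rand(2, 6)                    # six size values, expanded to 6 * 256 sinusoidal features
style = torch.zeros(2, dtype=torch.long)   # index into the single learned style embedding

cond_v11 = emb_v11(timestep, text, meta, style, hidden_dtype=torch.float32)
cond_v12 = emb_v12(timestep, text, None, None, hidden_dtype=torch.float32)  # meta/style unused here
print(cond_v11.shape, cond_v12.shape)  # both (2, 1408): timestep emb + projected extra_cond
```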

src/diffusers/models/transformers/hunyuan_transformer_2d.py (4 additions, 0 deletions)
```diff
@@ -249,6 +249,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             The length of the clip text embedding.
         text_len_t5 (`int`, *optional*):
             The length of the T5 text embedding.
+        use_style_cond_and_image_meta_size (`bool`, *optional*):
+            Whether or not to use style condition and image meta size. True for version <= 1.1, False for version >= 1.2.
     """
 
     @register_to_config
@@ -270,6 +272,7 @@ def __init__(
         pooled_projection_dim: int = 1024,
         text_len: int = 77,
         text_len_t5: int = 256,
+        use_style_cond_and_image_meta_size: bool = True,
     ):
         super().__init__()
         self.out_channels = in_channels * 2 if learn_sigma else in_channels
@@ -301,6 +304,7 @@ def __init__(
             pooled_projection_dim=pooled_projection_dim,
             seq_len=text_len_t5,
             cross_attention_dim=cross_attention_dim_t5,
+            use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
         )
 
         # HunyuanDiT Blocks
```
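Because the new argument is passed through `@register_to_config`, a checkpoint can opt out of style/size conditioning purely through its saved transformer config, with no code changes at load time. A hedged loading sketch follows; the repo id is an assumption for illustration, not something stated in this commit.

```python
# Sketch only: loading a v1.2-style checkpoint with a diffusers version
# that includes this commit. The repo id below is assumed for illustration.
import torch
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",  # assumed checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

# v1.2 checkpoints are expected to ship use_style_cond_and_image_meta_size=False
print(pipe.transformer.config.use_style_cond_and_image_meta_size)

image = pipe(prompt="a watercolor painting of a fox").images[0]
image.save("fox.png")
```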
