@@ -717,7 +717,14 @@ def forward(self, x):
 
 
 class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+    def __init__(
+        self,
+        embedding_dim,
+        pooled_projection_dim=1024,
+        seq_len=256,
+        cross_attention_dim=2048,
+        use_style_cond_and_image_meta_size=True,
+    ):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -726,9 +733,15 @@ def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross
         self.pooler = HunyuanDiTAttentionPool(
             seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
         )
+
         # Here we use a default learned embedder layer for future extension.
-        self.style_embedder = nn.Embedding(1, embedding_dim)
-        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+        if use_style_cond_and_image_meta_size:
+            self.style_embedder = nn.Embedding(1, embedding_dim)
+            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        else:
+            extra_in_dim = pooled_projection_dim
+
         self.extra_embedder = PixArtAlphaTextProjection(
             in_features=extra_in_dim,
             hidden_size=embedding_dim * 4,
@@ -743,16 +756,20 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde
         # extra condition1: text
         pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
 
-        # extra condition2: image meta size embdding
-        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
-        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
-        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+        if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embedding
+            image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
+            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
 
-        # extra condition3: style embedding
-        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+            # extra condition3: style embedding
+            style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+
+            # Concatenate all extra vectors
+            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        else:
+            extra_cond = torch.cat([pooled_projections], dim=1)
 
-        # Concatenate all extra vectors
-        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
         conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
 
         return conditioning
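
A minimal usage sketch of the new flag, to illustrate the two construction paths the diff introduces; this is a reviewer's illustration, not part of the commit. The class is `HunyuanCombinedTimestepTextSizeStyleEmbedding` from diffusers' embeddings module; the concrete numbers here (`embedding_dim=1408`, batch size 2, a 1024x1024 size tuple) are assumptions chosen for illustration, and the (N, 6) `image_meta_size` layout is inferred from the `6 * 256` reshape in the diff above.

import torch
from diffusers.models.embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding

# Default path (use_style_cond_and_image_meta_size=True): extra_embedder sees
# pooled text + image-meta-size + style embeddings,
# i.e. extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim.
emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
    embedding_dim=1408,  # assumed hidden size for illustration
    pooled_projection_dim=1024,
    seq_len=256,
    cross_attention_dim=2048,
)

# New path: only the pooled text projection feeds extra_embedder
# (extra_in_dim = pooled_projection_dim); image_meta_size and style are ignored.
emb_no_style = HunyuanCombinedTimestepTextSizeStyleEmbedding(
    embedding_dim=1408,
    pooled_projection_dim=1024,
    seq_len=256,
    cross_attention_dim=2048,
    use_style_cond_and_image_meta_size=False,
)

timestep = torch.randint(0, 1000, (2,))            # (N,)
encoder_hidden_states = torch.randn(2, 256, 2048)  # (N, seq_len, cross_attention_dim)
image_meta_size = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * 2)  # (N, 6) size conditions
style = torch.zeros(2, dtype=torch.long)           # index into the single learned style embedding

cond = emb(timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=torch.float32)
print(cond.shape)  # torch.Size([2, 1408]), i.e. (N, embedding_dim)

With `emb_no_style`, the same call returns the same shape, but only `timestep` and `encoder_hidden_states` contribute to the conditioning, which is what lets checkpoints trained without style/size conditions reuse this module.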