Add support for lumina2 #10642

Merged: 28 commits, Feb 11, 2025
Changes from 8 commits
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -110,6 +110,7 @@
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"LuminaNextDiT2DModel",
"Lumina2Transformer2DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
@@ -329,6 +330,7 @@
"LTXImageToVideoPipeline",
"LTXPipeline",
"LuminaText2ImgPipeline",
"Lumina2Text2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
@@ -622,6 +624,7 @@
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
Lumina2Transformer2DModel,
MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
@@ -820,6 +823,7 @@
LTXImageToVideoPipeline,
LTXPipeline,
LuminaText2ImgPipeline,
Lumina2Text2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -60,6 +60,7 @@
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
_import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
@@ -139,6 +140,7 @@
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
Lumina2Transformer2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
PriorTransformer,
1 change: 0 additions & 1 deletion src/diffusers/models/attention.py
@@ -612,7 +612,6 @@ def __init__(
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
inner_dim = int(2 * inner_dim / 3)
# custom hidden_size factor multiplier
if ffn_dim_multiplier is not None:
inner_dim = int(ffn_dim_multiplier * inner_dim)
97 changes: 97 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -4192,6 +4192,102 @@ def __call__(
return hidden_states


class Lumina2AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Lumina2Transformer2DModel. It applies normalization layers and rotary embeddings to the query and key vectors.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Lumina2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[torch.Tensor] = None,
key_rotary_emb: Optional[torch.Tensor] = None,
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = hidden_states.shape

# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype

# Get key-value heads
kv_heads = inner_dim // head_dim

query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)

# Apply Query-Key Norm if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if query_rotary_emb is not None:
query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
if key_rotary_emb is not None:
key = apply_rotary_emb(key, key_rotary_emb, use_real=False)

query, key = query.to(dtype), key.to(dtype)

# Apply proportional attention if true
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale

# perform Grouped-Query Attention (GQA)
n_rep = attn.heads // kv_heads
if n_rep >= 1:
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=softmax_scale
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states)

return hidden_states


class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -6183,6 +6279,7 @@ def __call__(
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
Lumina2AttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
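The less standard parts of Lumina2AttnProcessor2_0 are the proportional attention scale and the grouped-query repeat. Below is a minimal, self-contained sketch of just those two steps; the tensor layout and the base_sequence_length name follow the processor above, while the concrete sizes are made-up example values and the scale= argument of scaled_dot_product_attention assumes PyTorch 2.1 or newer.

import math
import torch
import torch.nn.functional as F

# Illustrative sizes only; they are not taken from any Lumina2 checkpoint.
batch_size, seq_len, heads, kv_heads, head_dim = 2, 1024, 24, 8, 64
base_sequence_length = 256  # assumed training-time token count for proportional attention

query = torch.randn(batch_size, seq_len, heads, head_dim)
key = torch.randn(batch_size, seq_len, kv_heads, head_dim)
value = torch.randn(batch_size, seq_len, kv_heads, head_dim)

# Proportional attention: scale the softmax by sqrt(log_base(seq_len)) so that
# longer sequences at inference keep a comparable attention sharpness.
softmax_scale = math.sqrt(math.log(seq_len, base_sequence_length)) * head_dim**-0.5

# Grouped-query attention: each key/value head serves heads // kv_heads query heads,
# implemented (as in the processor) by repeating key and value along the head axis.
n_rep = heads // kv_heads
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

out = F.scaled_dot_product_attention(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), scale=softmax_scale
)
print(out.shape)  # torch.Size([2, 24, 1024, 64])
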
32 changes: 32 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -1766,6 +1766,38 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde
return conditioning


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5):
super().__init__()

from .normalization import RMSNorm

self.time_proj = Timesteps(
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
)

self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024))

self.caption_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(
cap_feat_dim,
hidden_size,
bias=True,
),
)

def forward(self, timestep, caption_feat):
# timestep embedding:
time_freq = self.time_proj(timestep)
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))

# caption condition embedding:
caption_embed = self.caption_embedder(caption_feat)

return time_embed, caption_embed


class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
super().__init__()
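For orientation, the module above produces two separate conditioning tensors: the timestep goes through sinusoidal frequencies and a small MLP to width min(hidden_size, 1024), and the caption features are RMS-normalized and projected to hidden_size. A short usage sketch with the default sizes from this diff; the batch and sequence lengths are illustrative, and the import path is assumed from where the diff places the class.

import torch
from diffusers.models.embeddings import Lumina2CombinedTimestepCaptionEmbedding

embedder = Lumina2CombinedTimestepCaptionEmbedding()  # hidden_size=4096, cap_feat_dim=2048

timestep = torch.rand(2)                 # one diffusion timestep per sample
caption_feat = torch.randn(2, 77, 2048)  # text-encoder hidden states (sequence length is arbitrary)

time_embed, caption_embed = embedder(timestep, caption_feat)
print(time_embed.shape)     # torch.Size([2, 1024]), i.e. min(hidden_size, 1024)
print(caption_embed.shape)  # torch.Size([2, 77, 4096])
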
3 changes: 1 addition & 2 deletions src/diffusers/models/normalization.py
@@ -219,14 +219,13 @@ def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine:
4 * embedding_dim,
bias=True,
)
self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm = RMSNorm(embedding_dim, eps=norm_eps)

def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
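The forward pass above follows the familiar AdaLN-style modulation used by the Lumina blocks: the conditioning embedding is projected to four chunks (scale and gate for attention and for the MLP), and the RMS-normalized hidden states are rescaled by 1 + scale. A tiny numeric sketch of that modulation step, with made-up sizes and a plain RMS norm written out in place of the diffusers RMSNorm class:

import torch
import torch.nn as nn

embedding_dim = 8
silu = nn.SiLU()
linear = nn.Linear(embedding_dim, 4 * embedding_dim, bias=True)

x = torch.randn(2, 16, embedding_dim)   # hidden states (batch, seq, dim)
emb = torch.randn(2, embedding_dim)     # conditioning embedding

# Project the conditioning into four chunks: scale/gate for attention and for the MLP.
scale_msa, gate_msa, scale_mlp, gate_mlp = linear(silu(emb)).chunk(4, dim=1)

# RMS-normalize, then rescale by the conditioning-dependent (1 + scale) factor.
rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
out = rms * (1 + scale_msa[:, None])
print(out.shape)  # torch.Size([2, 16, 8])
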
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -98,7 +98,7 @@ def __init__(

self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
inner_dim=int(4 * 2 * dim / 3),
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
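
This call-site change pairs with the removal of inner_dim = int(2 * inner_dim / 3) inside LuminaFeedForward (see the attention.py hunk above): the 2/3 SwiGLU-style scaling is now applied by the caller rather than inside the module, so the effective feed-forward width appears unchanged. A quick arithmetic check under that reading of the diff, with an illustrative dim:

# Old path: pass 4 * dim and let LuminaFeedForward scale it by 2/3 internally.
# New path: pass the already-scaled value int(4 * 2 * dim / 3) directly.
dim = 2304  # illustrative hidden size, not necessarily the Lumina default

old_inner_dim = int(2 * (4 * dim) / 3)
new_inner_dim = int(4 * 2 * dim / 3)
assert old_inner_dim == new_inner_dim == 6144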