follow-up refactor on lumina2 #10776

Merged · 9 commits · Feb 15, 2025
Changes from 7 commits
189 changes: 82 additions & 107 deletions src/diffusers/models/transformers/transformer_lumina2.py
@@ -241,97 +241,85 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300,

def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
freqs_cis = []
# Use float32 for MPS compatibility
dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
freqs_cis.append(emb)
return freqs_cis

def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
device = ids.device
if ids.device.type == "mps":
ids = ids.to("cpu")

result = []
for i in range(len(self.axes_dim)):
freqs = self.freqs_cis[i].to(ids.device)
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
return torch.cat(result, dim=-1)
return torch.cat(result, dim=-1).to(device)
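Aside: a minimal standalone sketch of the per-axis lookup that _get_freqs_cis performs; each axis's precomputed frequency table is indexed by that axis's integer position ids and the three results are concatenated. The dims, lengths, and random tables below are illustrative stand-ins, not values from the model config.

import torch

axes_dim = [8, 6, 6]                     # per-axis embedding dims (illustrative)
axes_lens = [32, 16, 16]                 # max positions per axis (illustrative)
# stand-ins for the tables built by _precompute_freqs_cis
freqs_cis = [torch.randn(length, dim) for length, dim in zip(axes_lens, axes_dim)]

ids = torch.randint(0, 16, (2, 10, 3))   # (batch, seq_len, num_axes) position ids
result = []
for i, freqs in enumerate(freqs_cis):
    # broadcast each axis's position id across that axis's frequency dim, then gather rows
    index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
    result.append(torch.gather(freqs.unsqueeze(0).repeat(ids.shape[0], 1, 1), dim=1, index=index))
print(torch.cat(result, dim=-1).shape)   # torch.Size([2, 10, 20])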

def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
batch_size = len(hidden_states)
p_h = p_w = self.patch_size
device = hidden_states[0].device
batch_size, channels, height, width = hidden_states.shape
p = self.patch_size
post_patch_height, post_patch_width = height // p, width // p
image_seq_len = post_patch_height * post_patch_width
device = hidden_states.device

encoder_seq_len = attention_mask.shape[1]
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
# TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]

max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
max_img_len = max(l_effective_img_len)
seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
max_seq_len = max(seq_lengths)

# Create position IDs
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

for i in range(batch_size):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // p_h, W // p_w
assert H_tokens * W_tokens == img_len
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
# add caption position ids
position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len

position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
# add image position ids
row_ids = (
torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
torch.arange(post_patch_height, dtype=torch.int32, device=device)
.view(-1, 1)
.repeat(1, post_patch_width)
.flatten()
)
col_ids = (
torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
torch.arange(post_patch_width, dtype=torch.int32, device=device)
.view(1, -1)
.repeat(post_patch_height, 1)
.flatten()
)
position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
position_ids[i, cap_seq_len:seq_len, 1] = row_ids
position_ids[i, cap_seq_len:seq_len, 2] = col_ids

# Get combined rotary embeddings
freqs_cis = self._get_freqs_cis(position_ids)

cap_freqs_cis_shape = list(freqs_cis.shape)
cap_freqs_cis_shape[1] = attention_mask.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

for i in range(batch_size):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]

flat_hidden_states = []
for i in range(batch_size):
img = hidden_states[i]
C, H, W = img.size()
img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_hidden_states.append(img)
hidden_states = flat_hidden_states
padded_img_embed = torch.zeros(
batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
# create separate rotary embeddings for captions and images
cap_freqs_cis = torch.zeros(
batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
)
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
for i in range(batch_size):
padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
padded_img_mask[i, : l_effective_img_len[i]] = True

return (
padded_img_embed,
padded_img_mask,
img_sizes,
l_effective_cap_len,
l_effective_img_len,
freqs_cis,
cap_freqs_cis,
img_freqs_cis,
max_seq_len,
img_freqs_cis = torch.zeros(
batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
)

for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]

# image patch embeddings
hidden_states = (
hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
.permute(0, 2, 4, 3, 5, 1)
.flatten(3)
.flatten(1, 2)
)

return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
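
Aside: a quick sanity check of the patchify reshape used above, compared against an F.unfold based reference; the shapes below are illustrative assumptions. Both orderings produce (batch, num_patches, p * p * channels) tokens.

import torch
import torch.nn.functional as F

batch_size, channels, height, width, p = 2, 16, 8, 12, 2
x = torch.randn(batch_size, channels, height, width)

# reshape from the refactored forward(): (B, C, H, W) -> (B, num_patches, p * p * C)
patched = (
    x.view(batch_size, channels, height // p, p, width // p, p)
    .permute(0, 2, 4, 3, 5, 1)
    .flatten(3)
    .flatten(1, 2)
)

# reference: unfold yields channel-major (C, p, p) patches, so reorder to (p, p, C)
ref = F.unfold(x, kernel_size=p, stride=p).transpose(1, 2)
ref = ref.reshape(batch_size, -1, channels, p, p).permute(0, 1, 3, 4, 2).flatten(2)

assert torch.equal(patched, ref)
print(patched.shape)  # torch.Size([2, 24, 64])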


class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
@@ -471,75 +459,62 @@ def forward(
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
use_mask_in_transformer: bool = True,
encoder_attention_mask: torch.Tensor,
use_mask: bool = True,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
batch_size = hidden_states.size(0)

# 1. Condition, positional & patch embedding
batch_size, _, height, width = hidden_states.shape

temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)

(
hidden_states,
hidden_mask,
hidden_sizes,
encoder_hidden_len,
hidden_len,
joint_rotary_emb,
encoder_rotary_emb,
hidden_rotary_emb,
max_seq_len,
) = self.rope_embedder(hidden_states, attention_mask)
context_rotary_emb,
noise_rotary_emb,
rotary_emb,
encoder_seq_lengths,
seq_lengths,
) = self.rope_embedder(hidden_states, encoder_attention_mask)

hidden_states = self.x_embedder(hidden_states)

# 2. Context & noise refinement
for layer in self.context_refiner:
# NOTE: mask not used for performance
encoder_hidden_states = layer(
encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb
)
encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
Collaborator (PR author) commented:

The slight difference we see in the output without the mask actually comes from here; I didn't see any effect on speed, so I set it to always use encoder_attention_mask for the context_refiner layers.

With this, for a single prompt we get identical outputs for use_mask_in_transformer=True and use_mask_in_transformer=False.

Testing script:
# test lumina2
import torch
from diffusers import Lumina2Text2ImgPipeline
import itertools
from pathlib import Path
import shutil


device = "cuda:1"

branch = "refactor_lumina2"
# branch = "main"
params = {
    'use_mask_in_transformer': [True, False],  
}

# Generate all combinations
param_combinations = list(itertools.product(*params.values()))

# Create output directory (remove if exists)
output_dir = Path(f"yiyi_test_6_outputs_{branch}")
if output_dir.exists():
    shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)

pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16).to(device)

prompt = [
    "focused exterior view on a living room of a limestone rock regular blocks shaped villa with sliding windows and timber screens in Provence along the cliff facing the sea, with waterfalls from the roof to the pool, designed by Zaha Hadid, with rocky textures and form, made of regular giant rock blocks stacked each other with infinity edge pool in front of it, blends in with the surrounding nature. Regular rock blocks. Giant rock blocks shaping the space. The image to capture the infinity edge profile of the pool and the flow of water going down creating a waterfall effect. Adriatic Sea. The design is sustainable and semi prefab. The photo is shot on a canon 5D mark 4",
    # "A capybara holding a sign that reads Hello World"
]

# Run test for each combination
for (mask,) in param_combinations:
    print(f"\nTesting combination:")
    print(f"  use_mask_in_transformer: {mask}")
    
    # Generate image
    generator = torch.Generator(device=device).manual_seed(0)
    images = pipe(
        prompt=prompt,
        num_inference_steps=25,
        use_mask_in_transformer=mask,
        generator=generator,
    ).images
    
    # Save images
    for i, image in enumerate(images):
        output_path = output_dir / f"output_mask{int(mask)}_prompt{i}.png"
        image.save(output_path)
        print(f"Saved to: {output_path}")


for layer in self.noise_refiner:
# NOTE: mask not used for performance
hidden_states = layer(
hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb
)
hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)

# 3. Joint Transformer blocks
max_seq_len = max(seq_lengths)
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
attention_mask[i, :seq_len] = True
joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]

hidden_states = joint_hidden_states

# 3. Attention mask preparation
mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
for i in range(batch_size):
cap_len = encoder_hidden_len[i]
img_len = hidden_len[i]
mask[i, : cap_len + img_len] = True
padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
hidden_states = padded_hidden_states

# 4. Transformer blocks
for layer in self.layers:
# NOTE: mask not used for performance
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb
layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
)
else:
hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb)
hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)

# 5. Output norm & projection & unpatchify
# 4. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb)

height_tokens = width_tokens = self.config.patch_size
# 5. Unpatchify
p = self.config.patch_size
output = []
for i in range(len(hidden_sizes)):
height, width = hidden_sizes[i]
begin = encoder_hidden_len[i]
end = begin + (height // height_tokens) * (width // width_tokens)
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
output.append(
hidden_states[i][begin:end]
.view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
hidden_states[i][encoder_seq_len:seq_len]
.view(height // p, width // p, p, p, self.out_channels)
.permute(4, 0, 2, 1, 3)
.flatten(3, 4)
.flatten(1, 2)
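
Aside: to make step 3 of the refactored forward() concrete, a minimal standalone sketch of packing variable-length caption tokens and a fixed-length image sequence into one padded joint sequence plus a boolean mask; sizes are illustrative assumptions. It also shows why the mask only matters when caption lengths differ within a batch: with a single prompt (or equal-length captions) every row of the mask is all True, so passing None is equivalent.

import torch

hidden_size = 8
encoder_seq_lengths = [5, 3]          # effective caption lengths per sample (illustrative)
image_seq_len = 4                     # identical for every sample at a fixed resolution
seq_lengths = [cap_len + image_seq_len for cap_len in encoder_seq_lengths]
max_seq_len = max(seq_lengths)
batch_size = len(seq_lengths)

encoder_hidden_states = torch.randn(batch_size, max(encoder_seq_lengths), hidden_size)
hidden_states = torch.randn(batch_size, image_seq_len, hidden_size)

attention_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.bool)
joint_hidden_states = torch.zeros(batch_size, max_seq_len, hidden_size)
for i, (cap_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
    attention_mask[i, :seq_len] = True                                     # real tokens
    joint_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]  # caption first
    joint_hidden_states[i, cap_len:seq_len] = hidden_states[i]             # then image tokens

# rows are all True only when caption lengths match, e.g. a single prompt
print(attention_mask.all(dim=1))      # tensor([ True, False])
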
21 changes: 7 additions & 14 deletions src/diffusers/pipelines/lumina2/pipeline_lumina2.py
@@ -24,8 +24,6 @@
from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -44,12 +42,6 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


if is_bs4_available():
pass

if is_ftfy_available():
pass

EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -525,7 +517,7 @@ def __call__(
system_prompt: Optional[str] = None,
cfg_trunc_ratio: float = 1.0,
cfg_normalization: bool = True,
use_mask_in_transformer: bool = True,
use_mask_in_transformer: bool = False,
max_sequence_length: int = 256,
) -> Union[ImagePipelineOutput, Tuple]:
"""
@@ -598,7 +590,8 @@ def __call__(
cfg_normalization (`bool`, *optional*, defaults to `True`):
Whether to apply normalization-based guidance scale.
use_mask_in_transformer (`bool`, *optional*, defaults to `True`):
Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain.
Whether to use the attention mask in `Lumina2Transformer2DModel` for the transformer blocks. This only
needs to be set to `True` when you pass a list of prompts with different lengths.
max_sequence_length (`int`, defaults to `256`):
Maximum sequence length to use with the `prompt`.

@@ -704,8 +697,8 @@ def __call__(
hidden_states=latents,
timestep=current_timestep,
encoder_hidden_states=prompt_embeds,
attention_mask=prompt_attention_mask,
use_mask_in_transformer=use_mask_in_transformer,
encoder_attention_mask=prompt_attention_mask,
use_mask=use_mask_in_transformer,
return_dict=False,
)[0]

@@ -715,8 +708,8 @@ def __call__(
hidden_states=latents,
timestep=current_timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_mask=negative_prompt_attention_mask,
use_mask_in_transformer=use_mask_in_transformer,
encoder_attention_mask=negative_prompt_attention_mask,
use_mask=use_mask_in_transformer,
return_dict=False,
)[0]
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
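
Aside: a hedged usage sketch of the renamed flag at the pipeline level, under the new default of False. The pipeline class and model id match the testing script above; the second prompt, device, and output path are illustrative assumptions.

import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to("cuda")

prompts = [
    "A capybara holding a sign that reads Hello World",
    "A watercolor painting of a lighthouse at dawn",  # different length, so the mask matters
]

images = pipe(
    prompt=prompts,
    num_inference_steps=25,
    use_mask_in_transformer=True,  # only needed because the prompts differ in length
    generator=torch.Generator("cuda").manual_seed(0),
).images
images[0].save("capybara.png")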