follow-up refactor on lumina2 #10776
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
- encoder_hidden_states = layer(
-     encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb
- )
+ encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
```
the slight difference we see in the output without the mask actually comes from here; I didn't see any effect on speed, so I set it to always use `encoder_attention_mask` for the `context_refiner` layers. With this, for single-prompt inference we get identical output for `use_mask_in_transformer=True` and `use_mask_in_transformer=False`.
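For reference, a minimal sketch of what the refactored context-refiner loop looks like after this change (the loop structure and the `context_refiner` attribute name are assumptions based on the diff above, not verbatim diffusers source):

```python
# Sketch (assumed names): the encoder attention mask is now passed
# unconditionally, instead of being gated on use_mask_in_transformer.
for layer in self.context_refiner:
    # before: attention_mask if use_mask_in_transformer else None
    # after:  always apply encoder_attention_mask
    encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
```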
testing script:

```python
# test lumina2
import itertools
import shutil
from pathlib import Path

import torch
from diffusers import Lumina2Text2ImgPipeline

device = "cuda:1"
branch = "refactor_lumina2"
# branch = "main"

params = {
    "use_mask_in_transformer": [True, False],
}

# Generate all combinations
param_combinations = list(itertools.product(*params.values()))

# Create output directory (remove if it exists)
output_dir = Path(f"yiyi_test_6_outputs_{branch}")
if output_dir.exists():
    shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to(device)

prompt = [
    "focused exterior view on a living room of a limestone rock regular blocks shaped villa with sliding windows and timber screens in Provence along the cliff facing the sea, with waterfalls from the roof to the pool, designed by Zaha Hadid, with rocky textures and form, made of regular giant rock blocks stacked each other with infinity edge pool in front of it, blends in with the surrounding nature. Regular rock blocks. Giant rock blocks shaping the space. The image to capture the infinity edge profile of the pool and the flow of water going down creating a waterfall effect. Adriatic Sea. The design is sustainable and semi prefab. The photo is shot on a canon 5D mark 4",
    # "A capybara holding a sign that reads Hello World",
]

# Run the test for each combination
for (mask,) in param_combinations:
    print("\nTesting combination:")
    print(f"  use_mask_in_transformer: {mask}")

    # Generate image with a fixed seed so the two runs are comparable
    generator = torch.Generator(device=device).manual_seed(0)
    images = pipe(
        prompt=prompt,
        num_inference_steps=25,
        use_mask_in_transformer=mask,
        generator=generator,
    ).images

    # Save images
    for i, image in enumerate(images):
        output_path = output_dir / f"output_mask{int(mask)}_prompt{i}.png"
        image.save(output_path)
        print(f"Saved to: {output_path}")
```
[output images for use_mask_in_transformer=True and use_mask_in_transformer=False]
Nice! I can confirm that I get the same image. This also makes text generation more reliable; it sometimes failed before without a mask.
@asomoza @a-r-r-o-w @hlky |
Sounds good @yiyixuxu |
This PR:
- changes the default of `use_mask_in_transformer` to be `False`, because the output is identical with `use_mask_in_transformer=False` (see details follow-up refactor on lumina2 #10776 (comment))
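With `False` as the default, callers can omit the flag entirely. A minimal sketch of the resulting user-facing call (assuming the default lands as described; mirrors the testing script above):

```python
import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to("cuda")

# use_mask_in_transformer now defaults to False, so it can be omitted;
# pass use_mask_in_transformer=True to restore the previous behavior.
image = pipe(
    "A capybara holding a sign that reads Hello World",
    num_inference_steps=25,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("capybara.png")
```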