
[WIP] correct the attn naming for UNet3DConditionModel #6873


Closed
wants to merge 2 commits

Conversation

yiyixuxu (Collaborator) commented Feb 6, 2024

Why did I open this PR?

The motivation for this PR is that I find it impossible to reason about how to configure a 3D UNet using CrossAttnDownBlock3D, CrossAttnUpBlock3D, and UNetMidBlock3DCrossAttn: these 3 blocks all have an argument named num_attention_heads that actually expects the value of attention_head_dim.

For example, this is the __init__ method of CrossAttnDownBlock3D:

class CrossAttnDownBlock3D(nn.Module):
    def __init__(self, ..., num_attention_heads, ...):
        ...
        attentions.append(
            Transformer2DModel(
                out_channels // num_attention_heads,  # positionally, Transformer2DModel's num_attention_heads
                num_attention_heads,                  # positionally, Transformer2DModel's attention_head_dim
                in_channels=out_channels,
                num_layers=1,
                cross_attention_dim=cross_attention_dim,
                norm_num_groups=resnet_groups,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention,
                upcast_attention=upcast_attention,
            )
        )

It is not really obvious that num_attention_heads is actually supposed to be attention_head_dim. Only if you look closely at how it creates Transformer2DModel, and then check against Transformer2DModel's signature, will you notice that the num_attention_heads here is passed as attention_head_dim to Transformer2DModel and then all the way down to Attention as the head dimension (dim_head).

All our text-to-video UNets are configured with attention_head_dim; for example, this one has attention_head_dim = 64.
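As a quick check that this is really what the config says (assuming the checkpoint in question is the damo-vilab/text-to-video-ms-1.7b one used in the sanity check below), you can load just the UNet and inspect its config:

from diffusers import UNet3DConditionModel

# load only the UNet of the text-to-video checkpoint and print the relevant config entry
unet = UNet3DConditionModel.from_pretrained("damo-vilab/text-to-video-ms-1.7b", subfolder="unet")
print(unet.config.attention_head_dim)  # expected: 64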

So what we do in UNet3DConditionModel is:

  1. we immediately assign attention_head_dim to num_attention_heads, i.e. inside UNet3DConditionModel.__init__ we do num_attention_heads = attention_head_dim
  2. we pass the 64 around under the name num_attention_heads:
    • UNet3DConditionModel calls get_down_block(num_attention_heads=num_attention_heads)
    • get_down_block then calls CrossAttnDownBlock3D(num_attention_heads=num_attention_heads)
  3. we swap the two arguments back when CrossAttnDownBlock3D calls Transformer2DModel and TransformerTemporalModel (see the sketch right after this list)
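
Here is a minimal, self-contained sketch of that flow (the Fake* names, the simplified signatures, and the (320, 640) block widths are stand-ins I made up for illustration, not the real diffusers code):

class FakeTransformer2DModel:
    # same positional order as Transformer2DModel: (num_attention_heads, attention_head_dim)
    def __init__(self, num_attention_heads, attention_head_dim, in_channels):
        self.heads = num_attention_heads
        self.dim_head = attention_head_dim


def fake_cross_attn_down_block_3d(out_channels, num_attention_heads):
    # despite its name, `num_attention_heads` here actually carries the head dim (64);
    # the swap back happens via the positional arguments below
    return FakeTransformer2DModel(
        out_channels // num_attention_heads,  # the real number of heads
        num_attention_heads,                  # really the head dim
        in_channels=out_channels,
    )


def fake_unet3d(block_out_channels=(320, 640), attention_head_dim=64):
    num_attention_heads = attention_head_dim  # step 1: rename the config value
    # step 2: pass it around under the misleading name
    return [fake_cross_attn_down_block_3d(ch, num_attention_heads) for ch in block_out_channels]


for block in fake_unet3d():
    print(block.heads, block.dim_head)  # 5 64, then 10 64 -- the head dim stays 64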

It took me so much effort to figure out what's going on, and I'm still confused. I know this was introduced in PR #3797, but I'm not sure why.

Unlike the 2D UNet models, the 3D models never really had the "wrong configuration" problem: they are configured with attention_head_dim instead of num_attention_heads, but that configuration wasn't "wrong", i.e. attention_head_dim = 64 actually means attention_head_dim = 64. This is different from the 2D UNet models. I'm aware that the 2D UNet has the wrong-configuration issue (the attention_head_dim in their config files should really be num_attention_heads), and I'm aware that's why we had to assign attention_head_dim to num_attention_heads for the 2D UNet. But that is not the case for the 3D models.

Did we do it this way so that all the cross-attention blocks accept only one argument, num_attention_heads, instead of two different arguments? If so, I would argue that even though that is not ideal, it is still better than the current arrangement: with the current arrangement I find it very difficult to reason about, and even harder to explain to other people 😭😭😭

So I tried to correct the argument names and deprecate things in this PR. I'm curious why we did it this way (I think it's very likely I missed something), and I'm very much open to any other solution that makes this confusion go away :)

Test

I will run the slow tests, but here is a quick sanity check to make sure that we are able to configure the heads and dim_head parameters of the Attention class correctly.

import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention

# 2D UNet: configured with num_attention_heads = 8
print(" ")
print(" unet2d")
repo = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16, variant="fp16")
for name, module in pipe.unet.named_modules():
    if isinstance(module, Attention):
        print(f"  module.inner_dim/module_heads:{module.inner_dim/module.heads}, module.heads:{module.heads}")
        print(f" module.scale: {module.scale}")

# 3D UNet: configured with attention_head_dim = 64
print(" ")
print(" unet3d")
repo = "damo-vilab/text-to-video-ms-1.7b"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16, variant="fp16")

for name, module in pipe.unet.named_modules():
    if isinstance(module, Attention):
        print(f"  module.inner_dim/module_heads:{module.inner_dim/module.heads}, module.heads:{module.heads}")
        print(f" module.scale: {module.scale}")
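
For reference when reading the output below: Attention sets scale = dim_head ** -0.5 (with dim_head being the inner_dim / heads value printed above), so the scale values map straight back to the head dim:

# expected scale values for the head dims that show up in the output (scale = dim_head ** -0.5)
for dim_head in (40, 80, 160, 64):
    print(dim_head, dim_head ** -0.5)  # ~0.1581, ~0.1118, ~0.0791, 0.125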

For the 2D UNet model, we configure with num_attention_heads = 8:

unet2d
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:80.0, module.heads:8
 module.scale: 0.11180339887498948
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:40.0, module.heads:8
 module.scale: 0.15811388300841897
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949
  module.inner_dim/module_heads:160.0, module.heads:8
 module.scale: 0.07905694150420949

For the 3D UNet models, we configure with attention_head_dim = 64, and the results are as expected (the head dim stays 64):

unet3d
  module.inner_dim/module_heads:64.0, module.heads:8
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:8
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:10
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:5
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
  module.inner_dim/module_heads:64.0, module.heads:20
 module.scale: 0.125
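
The varying head counts above come from out_channels // attention_head_dim per block. Assuming block widths of 320, 640, and 1280 for this checkpoint (an assumption on my side, but it matches the 5/10/20 heads printed), that works out to:

# head counts per block width, assuming widths (320, 640, 1280) and attention_head_dim = 64
for out_channels in (320, 640, 1280):
    print(out_channels // 64)  # 5, 10, 20 -- while dim_head stays 64 everywhere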

yiyixuxu marked this pull request as draft February 6, 2024 09:33

yiyixuxu marked this pull request as ready for review February 7, 2024 10:22
DN6 (Collaborator) commented Feb 7, 2024

@yiyixuxu See my comment here regarding this: #6872 (comment)

I think we would want to configure everything to use num_attention_heads instead of attention_head_dim so that the 2D UNet and 3D UNet behave similarly.

Here's a draft PR with what I'm proposing: #6893

yiyixuxu (Collaborator, Author) commented Feb 7, 2024

I think another option is:

patrickvonplaten (Contributor) left a comment (this comment was marked as outdated)

github-actions bot commented Mar 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label Mar 7, 2024
yiyixuxu (Collaborator, Author) commented Mar 7, 2024

closing in favor of #6893

yiyixuxu closed this Mar 7, 2024