[Refactor] How attention is set in 3D UNet blocks #6893

Open · wants to merge 3 commits into main

8 changes: 4 additions & 4 deletions src/diffusers/models/unets/unet_3d_blocks.py
@@ -497,8 +497,8 @@ def __init__(
)
attentions.append(
Transformer2DModel(
-out_channels // num_attention_heads,
num_attention_heads,
+out_channels // num_attention_heads,
Collaborator

CrossAttnDownBlock3D is part of our public API and this is a breaking change, no?

@DN6 (Collaborator Author) Feb 7, 2024

Sorry, why isn't it backwards compatible? None of the args in the class init are being changed, right?

The 3D blocks and the 3D UNet are only used with the Text to Video Synth and I2VGenXL models in the library.

Collaborator Author

Same with get_up_block and get_down_block. We're only changing the number being passed to the `num_attention_heads` argument.

Collaborator

Changing the meaning of an argument is breaking, no?

For example, `CrossAttnDownBlock3D(..., num_attention_heads=64)` is currently expected to create attentions with head_dim=64; with this code change, it will create attentions with 64 heads instead.
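
A small sketch of that shift with made-up numbers (out_channels=1280, head count 64), relying on the same fact the diff does, namely that Transformer2DModel takes num_attention_heads and then attention_head_dim as its first two positional arguments:

```python
from diffusers.models import Transformer2DModel

out_channels, num_attention_heads = 1280, 64  # hypothetical block width / head count

# Old CrossAttnDownBlock3D behaviour: 1280 // 64 = 20 heads of dim 64,
# i.e. the block's num_attention_heads argument effectively acted as head_dim.
old_attn = Transformer2DModel(out_channels // num_attention_heads, num_attention_heads,
                              in_channels=out_channels)

# Behaviour after this PR: 64 heads of dim 1280 // 64 = 20,
# i.e. num_attention_heads now really is the number of heads.
new_attn = Transformer2DModel(num_attention_heads, out_channels // num_attention_heads,
                              in_channels=out_channels)
```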

@yiyixuxu (Collaborator) Feb 7, 2024

> The 3D blocks and the 3D UNet are only used with the Text to Video Synth and I2VGenXL models in the library.

Yes, but it is our public API and we have to assume it's been used outside of the library.

@DN6 (Collaborator Author) Feb 7, 2024

I think this might be a relatively safe change. Searching GitHub public repos for an import of these blocks from the public API doesn't return any results. I actually don't think you can import them directly from diffusers:

https://github.com/search?q=%22from+diffusers.models+import+CrossAttnDownBlock3D%22+language:Python+&type=code
https://github.com/search?q=%22from+diffusers+import+CrossAttnUpBlock3D%22+language:Python+&type=code
https://github.com/search?q=%22from+diffusers+import+UNetMidBlock3DCrossAttn%22+language:Python+&type=code

It looks like, more often than not, people redefine the blocks themselves:
https://github.com/search?q=%22CrossAttnDownBlock3D%22+language:Python+&type=code&p=5

@yiyixuxu (Collaborator) Feb 7, 2024

cc @pcuenca and @patrickvonplaten here.
I would like to hear your thoughts about when we can make breaking changes (other than v1.0.0).

Personally, I think we should only make breaking changes when we don't have another choice, or we know super confidently it is an edge case (e.g., if we had just added these blocks yesterday, I would think it's ok to break here).
I think in this case:
(1) we do not have to make these changes: we want to make them so our code is more readable and easier for contributors to contribute to, but it is not a must and this is not the only way to go;
(2) we don't really have a way to find out about its usage outside GitHub.

Also, I think a breaking change is somewhat more acceptable if we are able to throw an error. In this case, it would just break silently, so IMO it is worse.

But I'm curious about your thoughts on this and I'm cool with it if you all feel strongly about making this change here :)

@patrickvonplaten (Contributor) Feb 9, 2024

python -c "from diffusers import CrossAttnDownBlock3D"

doesn't work, so strictly speaking CrossAttnDownBlock3D is not considered part of the public API. Also, I don't think it's used that much, so IMO it's ok to change it here (while keeping in mind, though, that this might lead to breaking changes depending on how CrossAttnDownBlock3D is imported).
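
A quick way to see both sides of that caveat (a sketch; the submodule path follows this PR's file layout under diffusers/models/unets, which is an assumption about the installed version):

```python
# Top-level import fails, which is what makes the class "not public API":
try:
    from diffusers import CrossAttnDownBlock3D
except ImportError:
    print("CrossAttnDownBlock3D is not re-exported at the top level")

# ...but the class is still reachable through its submodule, and downstream code
# importing it this way would silently pick up the new argument semantics:
from diffusers.models.unets.unet_3d_blocks import CrossAttnDownBlock3D
```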

Comment on lines -500 to +501
Contributor

Let's use keyword arguments here when correcting it.
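
Roughly what that could look like inside the block's __init__, using the names from the surrounding diff (a sketch of the suggestion, not the committed code):

```python
attentions.append(
    Transformer2DModel(
        num_attention_heads=num_attention_heads,
        attention_head_dim=out_channels // num_attention_heads,
        in_channels=out_channels,
        num_layers=1,
        cross_attention_dim=cross_attention_dim,
    )
)
```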

in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -510,8 +510,8 @@ def __init__(
)
temp_attentions.append(
TransformerTemporalModel(
-out_channels // num_attention_heads,
num_attention_heads,
+out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -731,8 +731,8 @@ def __init__(
)
attentions.append(
Transformer2DModel(
-out_channels // num_attention_heads,
num_attention_heads,
+out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -744,8 +744,8 @@ def __init__(
)
temp_attentions.append(
TransformerTemporalModel(
-out_channels // num_attention_heads,
num_attention_heads,
+out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
8 changes: 7 additions & 1 deletion src/diffusers/models/unets/unet_3d_condition.py
@@ -136,13 +136,19 @@ def __init__(
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
Contributor

I would prefer to try to remove this statement.

)

+if isinstance(attention_head_dim, int):
+    num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
@patrickvonplaten (Contributor) Feb 9, 2024

That's a good idea.

Are we sure though that this is always correct? Does out_channels always represent the hidden_dim of the attention layer?

Collaborator Author

For the 3D UNets this is safe. There are a limited number of blocks used with this model: `CrossAttnDownBlock3D`, `CrossAttnUpBlock3D`, and `UNetMidBlock3DCrossAttn`, and they all configure `num_attention_heads` based on the `out_channels`.

+else:
+    num_attention_heads = [
+        out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
+    ]
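
A worked example of the conversion added above, using hypothetical values in the style of a text-to-video UNet config (attention_head_dim=64, block_out_channels=(320, 640, 1280, 1280)):

```python
block_out_channels = (320, 640, 1280, 1280)  # hypothetical example values
attention_head_dim = 64

if isinstance(attention_head_dim, int):
    num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
else:
    num_attention_heads = [
        out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
    ]

print(num_attention_heads)  # [5, 10, 20, 20] -> each block keeps head_dim == 64
```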

# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
-num_attention_heads = num_attention_heads or attention_head_dim
Member

Also remove or update the comment above, assuming num_attention_heads is replicated in the Hub configs. Do we know how many models like https://huggingface.co/ali-vilab/i2vgen-xl/blob/6c4e9e70bdcd36eb59d98d2b583adea0813ea8de/unet/config.json#L21 we would need to update?

Will we need to live forever with duplicated property names on the Hub?

@yiyixuxu (Collaborator) Feb 7, 2024

Just this one model, and it was only released a few days ago.
IMO it's an edge case we don't mind breaking - it will only affect people who want to use the local copy, no?


# Check inputs
if len(down_block_types) != len(up_block_types):