Skip to content

Commit c147a42

Browse files
yiyixuxuyiyixuxuDN6
authored andcommitted
correct attention_head_dim for JointTransformerBlock (#8608)
* add * update sd3 controlnet * Update src/diffusers/models/controlnet_sd3.py --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Dhruv Nair <[email protected]>
1 parent 08db291 commit c147a42

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

src/diffusers/models/attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl
128128
query_dim=dim,
129129
cross_attention_dim=None,
130130
added_kv_proj_dim=dim,
131-
dim_head=attention_head_dim // num_attention_heads,
131+
dim_head=attention_head_dim,
132132
heads=num_attention_heads,
133-
out_dim=attention_head_dim,
133+
out_dim=dim,
134134
context_pre_only=context_pre_only,
135135
bias=True,
136136
processor=processor,

src/diffusers/models/controlnet_sd3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
JointTransformerBlock(
8282
dim=self.inner_dim,
8383
num_attention_heads=num_attention_heads,
84-
attention_head_dim=self.inner_dim,
84+
attention_head_dim=self.config.attention_head_dim,
8585
context_pre_only=False,
8686
)
8787
for i in range(num_layers)

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(
9797
JointTransformerBlock(
9898
dim=self.inner_dim,
9999
num_attention_heads=self.config.num_attention_heads,
100-
attention_head_dim=self.inner_dim,
100+
attention_head_dim=self.config.attention_head_dim,
101101
context_pre_only=i == num_layers - 1,
102102
)
103103
for i in range(self.config.num_layers)

0 commit comments

Comments
 (0)