Skip to content

Commit e9fee1a

Browse files
committed
No trace tensors
1 parent ea85c86 commit e9fee1a

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

src/diffusers/models/attention_processor.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,7 @@ def __call__(
11411141
hidden_states = hidden_states = F.scaled_dot_product_attention(
11421142
query, key, value, dropout_p=0.0, is_causal=False
11431143
)
1144-
trace_tensor("attn_out", hidden_states[0,0,0,0])
1144+
#trace_tensor("attn_out", hidden_states[0,0,0,0])
11451145

11461146
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
11471147
hidden_states = hidden_states.to(query.dtype)
@@ -1152,7 +1152,7 @@ def __call__(
11521152
hidden_states[:, residual.shape[1] :],
11531153
)
11541154
hidden_states_cl = hidden_states.clone()
1155-
trace_tensor("attn_out", hidden_states_cl[0,0,0])
1155+
#trace_tensor("attn_out", hidden_states_cl[0,0,0])
11561156
# linear proj
11571157
hidden_states = attn.to_out[0](hidden_states_cl)
11581158
# dropout
@@ -1221,13 +1221,9 @@ def __call__(
12211221
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
12221222
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
12231223
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1224-
trace_tensor("query", query)
1225-
trace_tensor("key", key)
1226-
trace_tensor("value", value)
12271224
hidden_states = hidden_states = F.scaled_dot_product_attention(
12281225
query, key, value, dropout_p=0.0, is_causal=False
12291226
)
1230-
trace_tensor("attn_out", hidden_states[:,:,:50])
12311227

12321228
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
12331229
hidden_states = hidden_states.to(query.dtype)
@@ -1597,10 +1593,6 @@ def __call__(
15971593
hidden_states = F.scaled_dot_product_attention(
15981594
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
15991595
)
1600-
trace_tensor("query", query)
1601-
trace_tensor("key", key)
1602-
trace_tensor("value", value)
1603-
trace_tensor("attn_out", hidden_states[:,:,:50])
16041596
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
16051597
hidden_states = hidden_states.to(query.dtype)
16061598

0 commit comments

Comments
 (0)