Skip to content

Commit 900cf7d

Browse files
committed
Fix a dtype issue when evaluating the Sana transformer with a float16 autocast context
1 parent 4b55713 commit 900cf7d

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5896,9 +5896,10 @@ def __call__(
5896 5896

5897 5897
query, key, value = query.float(), key.float(), value.float()
5898 5898

5899-
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
5900-
scores = torch.matmul(value, key)
5901-
hidden_states = torch.matmul(scores, query)
5899+
with torch.autocast(device_type=hidden_states.device.type, enabled=False):
5900+
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
5901+
scores = torch.matmul(value, key)
5902+
hidden_states = torch.matmul(scores, query)
5902 5903

5903 5904
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
5904 5905
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)

0 commit comments

Comments
 (0)