Commit ea85c86

SD3 bringup
1 parent 4a2f260 commit ea85c86

File tree

src/diffusers/models/attention_processor.py
src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

2 files changed: +30 −9

src/diffusers/models/attention_processor.py

+22 −5
@@ -25,6 +25,7 @@
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRALinearLayer
+from shark_turbine.ops.iree import trace_tensor


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
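
The new import is the hook the rest of this commit relies on: shark_turbine.ops.iree.trace_tensor(name, tensor) tags a named intermediate so it can be inspected when the model is run for SD3 bring-up. A rough sketch of the pattern follows; only trace_tensor itself comes from the import above, while the toy module, names, and shapes are invented for illustration:

# Hypothetical sketch of the tracing pattern used in the hunks below; the module,
# names, and shapes are illustrative, not part of diffusers.
import torch
import torch.nn.functional as F
from shark_turbine.ops.iree import trace_tensor


class ToyAttention(torch.nn.Module):
    def forward(self, query, key, value):
        out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        # Tag a small slice of the output under a stable name for later comparison.
        trace_tensor("toy_attn_out", out[0, 0, 0])
        return out

How the traced values are surfaced (logged eagerly or dumped by the runtime that executes the compiled module) depends on how the model is run, which this commit does not show.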
@@ -816,6 +817,8 @@ def __call__(
         value = attn.head_to_batch_dim(value)

         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+
         hidden_states = torch.bmm(attention_probs, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)

@@ -922,6 +925,7 @@ def __call__(
         value = attn.head_to_batch_dim(value)

         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
         hidden_states = torch.bmm(attention_probs, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)

@@ -1131,10 +1135,14 @@ def __call__(
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+        # trace_tensor("query", query[0,0,0])
+        # trace_tensor("key", key[0,0,0])
+        # trace_tensor("value", value[0,0,0])
         hidden_states = hidden_states = F.scaled_dot_product_attention(
             query, key, value, dropout_p=0.0, is_causal=False
         )
+        trace_tensor("attn_out", hidden_states[0,0,0,0])
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

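In this hunk only a single element of the SDPA output is traced (hidden_states[0,0,0,0]), with the query/key/value traces left commented out. One plausible way such a scalar probe gets used during bring-up is to recompute the same element with eager PyTorch and compare; the helper below is hypothetical and not part of the commit:

# Hypothetical comparison helper for a traced scalar; not part of diffusers or
# shark_turbine. It recomputes the probed element with eager PyTorch.
import torch
import torch.nn.functional as F


def eager_attn_probe(query, key, value):
    out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
    return out[0, 0, 0, 0].item()


def probe_matches(traced_value, query, key, value, atol=1e-3):
    # True if the value reported by the trace agrees with the eager PyTorch result.
    return abs(traced_value - eager_attn_probe(query, key, value)) <= atol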
@@ -1143,9 +1151,10 @@ def __call__(
             hidden_states[:, : residual.shape[1]],
             hidden_states[:, residual.shape[1] :],
         )
-
+        hidden_states_cl = hidden_states.clone()
+        trace_tensor("attn_out", hidden_states_cl[0,0,0])
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[0](hidden_states_cl)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
         if not attn.context_pre_only:
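Here the joint hidden states are split, cloned, and the clone is both traced and fed to the output projection. The clone is presumably there because the split produces views that alias the joint tensor; tracing a standalone copy avoids handing the trace op an aliased view. A small standalone illustration of that aliasing, with arbitrary shapes:

# Minimal illustration of why cloning before tracing can matter: slicing returns
# a view that shares storage with the original tensor, while .clone() does not.
import torch

joint = torch.randn(2, 154, 1536)
text_part = joint[:, :77]        # a view -- shares joint's storage
text_copy = text_part.clone()    # an independent, contiguous tensor
print(text_part.data_ptr() == joint.data_ptr())  # True
print(text_copy.data_ptr() == joint.data_ptr())  # False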
@@ -1212,10 +1221,14 @@ def __call__(
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+        trace_tensor("query", query)
+        trace_tensor("key", key)
+        trace_tensor("value", value)
         hidden_states = hidden_states = F.scaled_dot_product_attention(
             query, key, value, dropout_p=0.0, is_causal=False
         )
+        trace_tensor("attn_out", hidden_states[:,:,:50])
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -1584,7 +1597,10 @@ def __call__(
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
-
+        trace_tensor("query", query)
+        trace_tensor("key", key)
+        trace_tensor("value", value)
+        trace_tensor("attn_out", hidden_states[:,:,:50])
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -1778,6 +1794,7 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

+8 −4
@@ -247,17 +247,21 @@ def step(

         sigma = self.sigmas[self.step_index]

-        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
-
+        condition = s_tmin <= sigma
+        condition1 = sigma <= s_tmax
+        gamma = torch.where(condition & condition1,
+                            torch.minimum(torch.tensor(s_churn / (len(self.sigmas) - 1)), torch.tensor(2**0.5 - 1)),
+                            torch.tensor(0.0))
+
         noise = randn_tensor(
             model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
         )

         eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)

-        if gamma > 0:
-            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
+        sample = torch.where(gamma > 0, sample + eps * (sigma_hat**2 - sigma**2) ** 0.5, sample)
+

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         # NOTE: "original_sample" should not be an expected prediction_type but is left in for
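
The scheduler change rewrites Python-level branching on sigma (a conditional expression for gamma and an if gamma > 0 guard) as torch.where, presumably so the step stays expressible as pure tensor ops when the scheduler is traced or exported. The two formulations should agree numerically; the sketch below checks that on made-up values and is not part of the commit:

# Side-by-side check of the branchy and torch.where formulations of the churn
# step; s_churn, s_tmin, s_tmax, sigma, and the tensors are arbitrary examples.
import torch

s_churn, s_tmin, s_tmax, n_sigmas = 0.5, 0.05, 50.0, 28
sigma = torch.tensor(1.2)
sample = torch.randn(4)
eps = torch.randn(4)

# Original, branchy form.
gamma_py = min(s_churn / (n_sigmas - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
sigma_hat = sigma * (gamma_py + 1)
sample_py = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 if gamma_py > 0 else sample

# torch.where form, as in this commit.
cond = (s_tmin <= sigma) & (sigma <= s_tmax)
gamma_tw = torch.where(cond,
                       torch.minimum(torch.tensor(s_churn / (n_sigmas - 1)), torch.tensor(2**0.5 - 1)),
                       torch.tensor(0.0))
sigma_hat_tw = sigma * (gamma_tw + 1)
sample_tw = torch.where(gamma_tw > 0, sample + eps * (sigma_hat_tw**2 - sigma**2) ** 0.5, sample)

print(torch.allclose(torch.tensor(gamma_py, dtype=torch.float32), gamma_tw))  # True
print(torch.allclose(sample_py, sample_tw, atol=1e-5))  # True, up to float32 rounding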
