Commit

skip mqa cases
micmelesse committed Feb 1, 2025
1 parent 1ff68ae · commit 6b691eb
Showing 1 changed file with 9 additions and 4 deletions.
flash_attn/flash_attn_triton_amd/test.py (13 changes: 9 additions & 4 deletions)
@@ -857,10 +857,7 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
         descale_do=descale_do
     )
 
-    # fp8 backward pass
-    dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8)
-
-    # compare
+    # compare forward
     if DEBUG:
         print()
         print("Compare fp8 against ref")
@@ -880,6 +877,14 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
         print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_fp8 is not None else None)
     torch.testing.assert_close(S_dmask_fp16.to(torch.float32) if S_dmask_fp16 is not None else None, S_dmask_fp8.to(torch.float32) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)
 
+    if HQ // HK != 1:
+        print("Skipping backward for MQA/GQA cases because atomic_add does not support fp8")
+        return
+
+    # fp8 backward pass
+    dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8)
+
+    # compare backward
     if DEBUG:
         print("dv_fp16:", dv_fp16, dv_fp16.shape)
         print("dv_fp8:", dv_fp8, dv_fp8.shape)
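
For context on the new early return: in MQA/GQA, HQ // HK query heads share each key/value head, so the backward pass has to sum the dK/dV contributions of every query head in a group into one buffer. When parallel programs own different query heads, that sum is expressed with atomic adds, and (per the commit message) atomic_add does not support fp8, so the fp8 backward comparison is skipped whenever HQ // HK != 1. Below is a minimal PyTorch sketch of that grouped accumulation; the shapes and names are illustrative, not taken from the kernel.

import torch

# Illustrative shapes, not from the test: 8 query heads sharing 2 KV heads (GQA).
HQ, HK, N_CTX, D_HEAD = 8, 2, 128, 64
group = HQ // HK  # group > 1 is the MQA/GQA case the test now skips

# Hypothetical per-query-head dK contributions, one slice per backward program.
dk_per_qhead = torch.randn(HQ, N_CTX, D_HEAD)

# Every query head in a group accumulates into the same KV-head buffer;
# a parallel kernel expresses each += below as an atomic add.
dk = torch.zeros(HK, N_CTX, D_HEAD)
for hq in range(HQ):
    dk[hq // group] += dk_per_qhead[hq]

# The same reduction without the loop, as a sanity check.
assert torch.allclose(dk, dk_per_qhead.view(HK, group, N_CTX, D_HEAD).sum(dim=1))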
