Skip to content

Commit

Permalink
dk works
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Feb 6, 2025
1 parent dd702d9 commit 1009f29
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions flash_attn/flash_attn_triton_amd/bwd_prefill_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def _bwd_preprocess(
Delta,
stride_ob, stride_oh, stride_om, stride_ok,
stride_deltab, stride_deltah, stride_deltam,
stride_descale_q_z,
cu_seqlens_q, max_seqlen_q,
Descale_do,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr,
ACTUAL_HEAD_DIM: tl.constexpr,
Expand Down Expand Up @@ -60,10 +62,16 @@ def _bwd_preprocess(
out_ptrs = O + offs_do
do_ptrs = DO + offs_do
# load
o = tl.load(out_ptrs, mask=mask_md, other=0.0).to(tl.float32)
do = tl.load(do_ptrs, mask=mask_md, other=0.0).to(tl.float32)
o = tl.load(out_ptrs, mask=mask_md, other=0.0)
do = tl.load(do_ptrs, mask=mask_md, other=0.0)
# compute and write-back to delta
delta = tl.sum(o * do, axis=1)
if IS_FP8:
descale_do = tl.load(Descale_do + bid * stride_descale_q_z + hid)

# NOTE: do is scaled into the fp8 range and o is in fp8 but should be in the same scale as fp32
delta = tl.sum(o.to(tl.float32) * (do * descale_do).to(tl.float32), axis=1)
else:
delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1)
delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam
tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m)

Expand Down Expand Up @@ -1028,8 +1036,12 @@ def attention_prefill_backward_triton_split_impl(
IS_FP8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
if IS_FP8:
FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
descale_q_stride_z = descale_q.stride(0)
descale_k_stride_z = descale_k.stride(0)
descale_v_stride_z = descale_v.stride(0)
else:
FP8_MAX = None
descale_q_stride_z = descale_k_stride_z = descale_v_stride_z = None

if dq is None:
dq = torch.zeros_like(q)
Expand Down Expand Up @@ -1084,7 +1096,9 @@ def attention_prefill_backward_triton_split_impl(
delta,
stride_ob, stride_oh, stride_om, stride_ok,
stride_deltab, stride_deltah, stride_deltam,
descale_q_stride_z,
cu_seqlens_q, max_seqlen_q,
descale_q,
BLOCK_M=PRE_BLOCK,
HEAD_DIM=HEAD_DIM,
ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM,
Expand Down

0 comments on commit 1009f29

Please sign in to comment.