Enable BWD fp8
This is a combination of 12 commits.

add backward test case

save clean up

disable ci

lse is good

dv matches

reduce diff

use do fp8 for dv

kinda working

group size is a constexpr

clean up a bit

everything except mqa/gqa works

skip mqa cases
micmelesse committed Feb 3, 2025
1 parent 3e5027c commit 297742b
Showing 5 changed files with 235 additions and 111 deletions.
30 changes: 24 additions & 6 deletions flash_attn/flash_attn_interface.py
@@ -273,6 +273,11 @@ def _flash_attn_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None,
descale_do=None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -301,6 +306,11 @@ def _flash_attn_backward(
deterministic,
None,
rng_state,
descale_q,
descale_k,
descale_v,
descale_p,
descale_do
)
return softmax_d
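
For reference, the descale_* arguments are per-tensor dequantization factors that the fp8 backward kernel uses to recover full-precision values from the quantized q, k, v, p, and do operands. The helper below is a minimal sketch of how such factors might be produced with simple per-tensor scaling; it is illustrative only and not part of this patch (the name fp8_quantize and the scaling scheme are assumptions).

import torch

def fp8_quantize(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Hypothetical per-tensor scaling: map the tensor's max magnitude onto the
    # representable fp8 range and return the reciprocal scale ("descale")
    # that a kernel would use to dequantize its inputs.
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().amax().clamp(min=1e-12)
    scale = fp8_max / amax
    x_fp8 = (x * scale).to(fp8_dtype)
    descale = (1.0 / scale).to(torch.float32).reshape(1)
    return x_fp8, descale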

@@ -823,7 +833,8 @@ def forward(
descale_q,
descale_k,
descale_v,
descale_p
descale_p,
descale_do
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -849,7 +860,7 @@ def forward(
descale_v=descale_v,
descale_p=descale_p,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
@@ -862,7 +873,7 @@

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
q, k, v, out, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(3)
dout_padded = dout
@@ -887,11 +898,16 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p,
descale_do=descale_do
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
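
The extra trailing None in the return tuple reflects PyTorch's autograd contract: backward must return exactly one gradient (or None) per positional argument of forward, and forward now takes descale_do as well. Below is a standalone toy Function illustrating that contract; it is a sketch for explanation only, not related to the attention kernel itself.

import torch

class ScaledMatmul(torch.autograd.Function):
    # Two differentiable inputs plus one non-differentiable descale tensor,
    # mirroring how the descale factors ride along through save_for_backward.
    @staticmethod
    def forward(ctx, a, b, descale):
        ctx.save_for_backward(a, b, descale)
        return (a @ b) * descale

    @staticmethod
    def backward(ctx, grad_out):
        a, b, descale = ctx.saved_tensors
        grad_a = (grad_out * descale) @ b.t()
        grad_b = a.t() @ (grad_out * descale)
        # One slot per forward input: descale gets None, just as the
        # descale_* arguments of the attention forward do above.
        return grad_a, grad_b, None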


class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -1151,7 +1167,8 @@ def flash_attn_func(
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
descale_p=None,
descale_do=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
Expand Down Expand Up @@ -1216,7 +1233,8 @@ def flash_attn_func(
descale_q,
descale_k,
descale_v,
descale_p
descale_p,
descale_do
)
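
A hypothetical call site for the extended signature, reusing the fp8_quantize sketch above. Shapes, dtypes, and the unit descale_p/descale_do values are assumptions for illustration, not taken from this commit; note that per the commit message, MQA/GQA head layouts are still skipped, so q and kv use the same head count here.

q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

q_fp8, descale_q = fp8_quantize(q)
k_fp8, descale_k = fp8_quantize(k)
v_fp8, descale_v = fp8_quantize(v)
descale_p = torch.ones(1, device="cuda", dtype=torch.float32)   # assumed scale for P = softmax(QK^T)
descale_do = torch.ones(1, device="cuda", dtype=torch.float32)  # dequant factor for dOut in the fp8 backward

out = flash_attn_func(
    q_fp8, k_fp8, v_fp8,
    causal=True,
    descale_q=descale_q,
    descale_k=descale_k,
    descale_v=descale_v,
    descale_p=descale_p,
    descale_do=descale_do,
)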


(Diffs for the remaining changed files are not shown here.)
