Enable BWD fp8 with per block scale factors for p and ds

This is a combination of 9 commits.

Enable BWD fp8

This is a combination of 12 commits.

add backward test case

save clean up

disable ci

lse is good

dv matches

reduce diff

use do fp8 for dv

kinda working

group size is a constexpr

clean up a bit

everything except mqa/gqa works

skip mqa cases

20 cases have nan on dropout

save what you have

disable tests

failing

enable tests

per block descale_p and descale_ds

use max(abs()) for the per-block scale factors (see the sketch below)

clean up tests a bit more
micmelesse committed Feb 4, 2025
1 parent 929f0e8 commit 4cd9e2a
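The commit message notes that the per-block descale factors for p and ds are derived with max(abs()). As a rough, editorial illustration of that idea (the real computation happens inside the Triton backward kernel; the helper name, block size, and fp8 maximum below are assumptions, not the fork's API):

```python
# Illustrative sketch of per-block descale factors via max(abs()); not the fork's code.
import torch

FP8_MAX = 448.0  # assumed maximum magnitude representable in an e4m3 fp8 format

def per_block_descale(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Return one descale factor per block along the sequence dimension.

    x is assumed to be (..., seqlen, head_dim); multiplying an fp8 block by its
    factor recovers the original scale.
    """
    seqlen = x.shape[-2]
    num_blocks = (seqlen + block_size - 1) // block_size
    pad = num_blocks * block_size - seqlen
    if pad:
        x = torch.nn.functional.pad(x, (0, 0, 0, pad))
    blocks = x.reshape(*x.shape[:-2], num_blocks, block_size, x.shape[-1])
    amax = blocks.abs().amax(dim=(-2, -1)).clamp(min=1e-12)
    return amax / FP8_MAX
```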
Showing 7 changed files with 427 additions and 210 deletions.
60 changes: 48 additions & 12 deletions flash_attn/flash_attn_interface.py
@@ -273,6 +273,11 @@ def _flash_attn_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
+    descale_q=None,
+    descale_k=None,
+    descale_v=None,
+    descale_p=None,
+    descale_do=None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -301,6 +306,11 @@ def _flash_attn_backward(
deterministic,
None,
rng_state,
+        descale_q,
+        descale_k,
+        descale_v,
+        descale_p,
+        descale_do
)
return softmax_d

@@ -369,6 +379,11 @@ def _flash_attn_varlen_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
+    descale_q=None,
+    descale_k=None,
+    descale_v=None,
+    descale_p=None,
+    descale_do=None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -402,6 +417,11 @@ def _flash_attn_varlen_backward(
deterministic,
None,
rng_state,
+        descale_q,
+        descale_k,
+        descale_v,
+        descale_p,
+        descale_do
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
@@ -823,7 +843,8 @@ def forward(
descale_q,
descale_k,
descale_v,
-        descale_p
+        descale_p,
+        descale_do
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -849,7 +870,7 @@ def forward(
descale_v=descale_v,
descale_p=descale_p,
)
-        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
+        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
@@ -862,7 +883,7 @@

@staticmethod
def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
+        q, k, v, out, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(3)
dout_padded = dout
@@ -887,11 +908,16 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
+            descale_q=descale_q,
+            descale_k=descale_k,
+            descale_v=descale_v,
+            descale_p=descale_p,
+            descale_do=descale_do
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -917,7 +943,8 @@ def forward(
descale_q,
descale_k,
descale_v,
-        descale_p
+        descale_p,
+        descale_do
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -949,7 +976,7 @@ def forward(
descale_p=descale_p
)
ctx.save_for_backward(
-            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
@@ -965,7 +992,7 @@

@staticmethod
def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(2)
dout_padded = dout
@@ -994,11 +1021,16 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
+            descale_q=descale_q,
+            descale_k=descale_k,
+            descale_v=descale_v,
+            descale_p=descale_p,
+            descale_do=descale_do
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
@@ -1151,7 +1183,8 @@ def flash_attn_func(
descale_q=None,
descale_k=None,
descale_v=None,
-    descale_p=None
+    descale_p=None,
+    descale_do=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -1216,7 +1249,8 @@ def flash_attn_func(
descale_q,
descale_k,
descale_v,
-        descale_p
+        descale_p,
+        descale_do
)


@@ -1396,7 +1430,8 @@ def flash_attn_varlen_func(
descale_q=None,
descale_k=None,
descale_v=None,
-    descale_p=None
+    descale_p=None,
+    descale_do=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1473,7 +1508,8 @@ def flash_attn_varlen_func(
descale_q,
descale_k,
descale_v,
-        descale_p
+        descale_p,
+        descale_do
)


3 changes: 3 additions & 0 deletions flash_attn/flash_attn_triton_amd/README.md
@@ -43,6 +43,9 @@ python setup.py install
pytest tests/test_flash_attn_triton_amd.py
```

+##### FP8
+In our fork, we have modified the API to work with fp8: you provide tensors scaled into the fp8 range along with their associated descaling factors.
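
For illustration, a minimal usage sketch of this interface (not taken from the fork's tests): it assumes the descale_* keyword arguments added in this commit, torch.float8_e4m3fn as the fp8 dtype, and simple per-tensor descale factors; descale_p and descale_do are left as None here.

```python
import torch
from flash_attn import flash_attn_func

FP8_MAX = 448.0  # assumed maximum magnitude of the e4m3 fp8 format

def to_fp8_with_descale(x: torch.Tensor):
    # Scale into fp8 range; x ~= x_fp8.float() * descale recovers the original values.
    descale = (x.abs().amax().float() / FP8_MAX).clamp(min=1e-12).reshape(1)
    return (x / descale).to(torch.float8_e4m3fn), descale

q = torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)

q8, descale_q = to_fp8_with_descale(q)
k8, descale_k = to_fp8_with_descale(k)
v8, descale_v = to_fp8_with_descale(v)

out = flash_attn_func(
    q8, k8, v8,
    causal=True,
    descale_q=descale_q,
    descale_k=descale_k,
    descale_v=descale_v,
    descale_p=None,   # per-block descale for p/ds (this commit); pass a factor if precomputed
    descale_do=None,  # descale factor for dout, used by the fp8 backward pass
)
```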

##### Credits
AMD Triton kernels team

