fp8 backward #119

Open · wants to merge 8 commits into base: main_perf
1 change: 1 addition & 0 deletions .github/workflows/amd_tests.yml
@@ -13,6 +13,7 @@ permissions: read-all

jobs:
Integration-Tests-AMD:
if: false
runs-on: ${{ matrix.runner }}
strategy:
matrix:
60 changes: 48 additions & 12 deletions flash_attn/flash_attn_interface.py
@@ -273,6 +273,11 @@ def _flash_attn_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None,
descale_do=None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -301,6 +306,11 @@ def _flash_attn_backward(
deterministic,
None,
rng_state,
descale_q,
descale_k,
descale_v,
descale_p,
descale_do
)
return softmax_d

@@ -369,6 +379,11 @@ def _flash_attn_varlen_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None,
descale_do=None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -402,6 +417,11 @@ def _flash_attn_varlen_backward(
deterministic,
None,
rng_state,
descale_q,
descale_k,
descale_v,
descale_p,
descale_do
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
@@ -823,7 +843,8 @@ def forward(
descale_q,
descale_k,
descale_v,
descale_p
descale_p,
descale_do
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -849,7 +870,7 @@ def forward(
descale_v=descale_v,
descale_p=descale_p,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
@@ -862,7 +883,7 @@

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
q, k, v, out, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(3)
dout_padded = dout
@@ -887,11 +908,16 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p,
descale_do=descale_do
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
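
A note on the return statement above: `torch.autograd.Function.backward` must return exactly one value per argument that `forward` received (excluding `ctx`), with `None` for inputs that get no gradient. Because `forward` gained a `descale_do` argument in this PR, `backward` now returns one extra `None`; the same reasoning applies to `FlashAttnVarlenFunc.backward` below. A minimal, standalone sketch of that contract (not flash-attention code):

```python
import torch

class ScaleByConstant(torch.autograd.Function):
    # forward takes two arguments after ctx: a tensor and a plain float
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one return value per forward argument:
        # the gradient for `x`, and None for the non-differentiable `factor`
        return grad_out * ctx.factor, None

x = torch.randn(4, requires_grad=True)
ScaleByConstant.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])
```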


class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -917,7 +943,8 @@ def forward(
descale_q,
descale_k,
descale_v,
descale_p
descale_p,
descale_do
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -949,7 +976,7 @@ def forward(
descale_p=descale_p
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
@@ -965,7 +992,7 @@

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_p, descale_do = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(2)
dout_padded = dout
@@ -994,11 +1021,16 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p,
descale_do=descale_do
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
@@ -1151,7 +1183,8 @@ def flash_attn_func(
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
descale_p=None,
descale_do=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -1216,7 +1249,8 @@ def flash_attn_func(
descale_q,
descale_k,
descale_v,
descale_p
descale_p,
descale_do
)


@@ -1396,7 +1430,8 @@ def flash_attn_varlen_func(
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
descale_p=None,
descale_do=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1473,7 +1508,8 @@ def flash_attn_varlen_func(
descale_q,
descale_k,
descale_v,
descale_p
descale_p,
descale_do
)


3 changes: 3 additions & 0 deletions flash_attn/flash_attn_triton_amd/README.md
@@ -43,6 +43,9 @@ python setup.py install
pytest tests/test_flash_attn_triton_amd.py
```

##### FP8
In our fork, we have modified the API to work with fp8. You provide tensors that are scaled to be in fp8 range and their associated descaling factors.

Reviewer comment:
I think that "scaled fp8 tensors" is better than "tensors that are scaled to be in fp8 range". I don't know, maybe someone can interpret this as an arbitrary data type scaled to be in fp8 range.

Do you think it's worth mentioning the descaling factors' type in this README?

Collaborator Author reply:
I am going to add more info to the README. This was just the start.
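
To make the FP8 note above concrete, here is a minimal usage sketch. The `quantize_to_fp8` helper, the per-tensor amax scaling scheme, the choice of `torch.float8_e4m3fn`, and the unit placeholders for `descale_p`/`descale_do` are illustrative assumptions; the diff does not specify how the descale factors should be computed or what dtype/shape they are expected to have.

```python
import torch
from flash_attn import flash_attn_func  # assumes this fork is installed

def quantize_to_fp8(x, fp8_dtype=torch.float8_e4m3fn):
    # Hypothetical helper: per-tensor amax scaling into the fp8 range,
    # returning the fp8 tensor and the matching descale factor.
    # (Some AMD GPUs may use the float8_e4m3fnuz variant instead.)
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().amax().clamp(min=1e-12).to(torch.float32)
    x_fp8 = (x * (fp8_max / amax)).to(fp8_dtype)
    descale = amax / fp8_max  # multiply results by this to undo the scaling
    return x_fp8, descale

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

q_fp8, descale_q = quantize_to_fp8(q)
k_fp8, descale_k = quantize_to_fp8(k)
v_fp8, descale_v = quantize_to_fp8(v)

# descale_p and descale_do describe tensors produced inside the kernels
# (the attention probabilities P and the output gradient dO); unit factors
# are used here purely as placeholders.
descale_p = torch.tensor(1.0, device="cuda")
descale_do = torch.tensor(1.0, device="cuda")

out = flash_attn_func(
    q_fp8, k_fp8, v_fp8,
    causal=True,
    descale_q=descale_q, descale_k=descale_k, descale_v=descale_v,
    descale_p=descale_p, descale_do=descale_do,
)
```

With this PR, the descale tensors are saved via `ctx.save_for_backward` in the forward pass and passed on to the backward call, so the same factors are available when gradients are computed through `out`.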


##### Credits
AMD Triton kernels team
