
fp8 backward #119

Open: micmelesse wants to merge 8 commits into main_perf from micmelesse/fp8_bwd
Conversation

@micmelesse (Collaborator) commented Jan 24, 2025

add fp8 backward

@micmelesse changed the title from "add backward test case" to "fp8 backward" on Jan 24, 2025
@micmelesse force-pushed the micmelesse/fp8_bwd branch 4 times, most recently from 6b691eb to 297742b on February 3, 2025 09:24
@micmelesse marked this pull request as ready for review on February 4, 2025 13:37
@brunomazzottiamd left a comment

I'm approving the PR because I can't see anything wrong with it. I just left some questions and cleanup suggestions.

@@ -43,6 +43,9 @@ python setup.py install
pytest tests/test_flash_attn_triton_amd.py
```

##### FP8
In our fork, we have modified the API to work with fp8. You provide tensors that are scaled into the fp8 range along with their associated descaling factors.
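As an illustration only (not part of this PR), a caller might prepare scaled fp8 tensors and their descale factors roughly as follows; the helper name and the chosen fp8 dtype are assumptions:

```python
# Hypothetical helper (not from this PR): scale a tensor into the fp8 range
# and return the descale factor needed to recover the original magnitudes.
import torch

def to_fp8_with_descale(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fnuz):
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().amax().clamp(min=1e-12)  # guard against all-zero tensors
    scale = fp8_max / amax
    x_fp8 = (x * scale).to(fp8_dtype)
    descale = 1.0 / scale                   # passed alongside the fp8 tensor
    return x_fp8, descale
```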


I think "scaled fp8 tensors" is better than "tensors that are scaled to be in fp8 range". Someone might read the latter as referring to an arbitrary data type scaled into the fp8 range.

Do you think it's worth mentioning the descaling factors' type in this README?

@micmelesse (Collaborator, Author) replied:

I am going to add more info to the README. This was just the start.

descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_h)

# do is scaled in the fp8 range and o is in fp8 but should be the same scale as fp32
# TODO: descale do so that we can use it as fp32


I think this TODO comment is outdated, since the descaling is already done.
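For context, applying that descale inside the Triton kernel would look roughly like the fragment below; this is a hypothetical sketch, and `do` stands for the already-loaded fp8 gradient block, which is not shown in the diff:

```python
# Hypothetical kernel fragment (not from this PR): the per-batch/per-head
# descale factor converts the fp8-scaled `do` block back to fp32 magnitudes.
descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_h)
do_fp32 = do.to(tl.float32) * descale_do
```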

ds = dscores_scaled * sm_scale
ds = tl.where(p_mask, ds, 0.0)

# print("p:", p)


Can we clean up these print statements before merging?

# print("delta_i:", delta_i)
# print("ds:", ds) # NOTE:is almost the same between fp8 and fp16
descale_ds = descale_p
ds_fp8_scaled = ds * (1.0/ descale_ds)


We only need to compute ds_fp8_scaled for the fp8 kernel; otherwise it's unused. What do you think of computing ds_fp8_scaled inside an if IS_FP8: block?
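A minimal sketch of that suggestion, assuming IS_FP8 is a constexpr flag already available in the kernel:

```python
# Hypothetical kernel fragment: compute the fp8-scaled ds only when the
# kernel is specialized for fp8 inputs; otherwise it stays unused.
if IS_FP8:
    descale_ds = descale_p
    ds_fp8_scaled = ds * (1.0 / descale_ds)
```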

@@ -553,6 +636,14 @@ def attention_prefill_backward_triton_impl(
print("use_exp2:", use_exp2)
print("sequence_parallel:", sequence_parallel)

is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
if is_fp8:


I think this empty if is_fp8: statement can be removed. Do we need to print or debug anything inside it?

Resolved review threads (3) on flash_attn/flash_attn_triton_amd/test.py
Commits

- and ds
  This is a combination of 9 commits:
  - Enable BWD fp8
    This is a combination of 12 commits:
    - add backward test case
    - save clean up
    - disable ci
    - lse is good
    - dv matches
    - reduce diff
    - use do fp8 for dv
    - kinda working
    - group size is a constexpr
    - clean up a bit
    - everything except mqa/gqa works
    - skip mqa cases
  - 20 cases have nan on dropout
  - save what you have
  - disable tests
  - failing
  - enable tests
  - per block descale_p and descale_ds
  - use max(abs(())
  - clean up tests a bit more