fp8 backward #119
base: main_perf
Conversation
Force-pushed from 6b691eb to 297742b
I'm approving the PR because I can't see anything wrong with it. I just left some questions and cleanup suggestions.
@@ -43,6 +43,9 @@ python setup.py install
pytest tests/test_flash_attn_triton_amd.py
```

##### FP8
In our fork, we have modified the API to work with fp8. You provide tensors that are scaled to be in fp8 range and their associated descaling factors.
I think "scaled fp8 tensors" reads better than "tensors that are scaled to be in fp8 range"; the current wording could be read as an arbitrary data type scaled into the fp8 range.
Do you think it's worth mentioning the descaling factors' type in this README?
I am going to add more info to the README. This was just the start.
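For context, a minimal sketch of the scaling side of the API as described above; the helper name `scale_to_fp8` and the per-tensor granularity are assumptions (the kernel code below indexes descale factors per batch and head):

```python
import torch

def scale_to_fp8(t: torch.Tensor, fp8_dtype=torch.float8_e4m3fnuz):
    # Map the tensor's max |value| onto the fp8 dtype's largest representable value.
    fp8_max = torch.finfo(fp8_dtype).max
    amax = t.abs().amax().clamp(min=1e-12)
    t_fp8 = (t.float() * (fp8_max / amax)).to(fp8_dtype)
    descale = (amax / fp8_max).float()  # multiply by this to recover fp32 magnitudes
    return t_fp8, descale

# q_fp8, descale_q = scale_to_fp8(q)  # likewise for k, v and, for the backward, do;
# the fp8 tensors and their descale factors are then passed to the attention call.
```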
descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_h)

# do is scaled in the fp8 range and o is in fp8 but should be the same scale as fp32
# TODO: descale do so that we can use it as fp32
I think this TODO comment is outdated since the descaling is already done.
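For readers of the thread, the resolved pattern presumably looks something like this (a sketch; `DO_ptrs` is a hypothetical pointer block, `descale_do` is the factor loaded above):

```python
do = tl.load(DO_ptrs)                     # values stored in the fp8 range
do_fp32 = do.to(tl.float32) * descale_do  # descale back to fp32 magnitudes
```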
ds = dscores_scaled * sm_scale
ds = tl.where(p_mask, ds, 0.0)

# print("p:", p)
Can we clean up these print statements before merging?
# print("delta_i:", delta_i) | ||
# print("ds:", ds) # NOTE:is almost the same between fp8 and fp16 | ||
descale_ds = descale_p | ||
ds_fp8_scaled = ds * (1.0/ descale_ds) |
We only need to compute ds_fp8_scaled for the fp8 kernel; otherwise it's unused. What do you think of computing ds_fp8_scaled inside an if IS_FP8: statement?
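Something like the following, assuming `IS_FP8` is (or can be made) a `tl.constexpr` kernel parameter so the dead branch is compiled out of the fp16 kernel (the commit notes mention constexpr flags being used elsewhere):

```python
if IS_FP8:
    # Only the fp8 path consumes ds_fp8_scaled, so keep it behind the flag.
    descale_ds = descale_p
    ds_fp8_scaled = ds * (1.0 / descale_ds)
    # ... fp8-specific dot products would use ds_fp8_scaled here ...
```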
@@ -553,6 +636,14 @@ def attention_prefill_backward_triton_impl(
print("use_exp2:", use_exp2)
print("sequence_parallel:", sequence_parallel)

is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
if is_fp8:
I think this empty if is_fp8: statement can be removed. Do we need to print or debug anything inside it?
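As a related cleanup, the dtype check could be factored into a small helper (a sketch; `arch_supports_fp8` is the function already used in the diff above, while the helper name is hypothetical):

```python
import torch

FP8_DTYPES = {torch.float8_e4m3fnuz, torch.float8_e4m3fn,
              torch.float8_e5m2, torch.float8_e5m2fnuz}

def input_is_fp8(q: torch.Tensor) -> bool:
    # Check hardware support first so non-fp8 GPUs fall back cleanly.
    return arch_supports_fp8() and q.dtype in FP8_DTYPES
```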
Squashed commit message: this is a combination of 9 commits; "Enable BWD fp8" was itself a combination of 12 commits:
- Enable BWD fp8
- add backward test case
- save
- clean up
- disable ci
- lse is good
- dv matches
- reduce diff
- use do fp8 for dv
- kinda working
- group size is a constexpr
- clean up a bit
- everything except mqa/gqa works
- skip mqa cases
- 20 cases have nan on dropout
- save what you have
- disable tests failing
- enable tests
- per block descale_p and descale_ds use max(abs())
- clean up tests a bit more
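The "per block descale_p and descale_ds use max(abs())" item likely corresponds to a pattern like this inside the kernel (a sketch; the `FP8_MAX` constant and the epsilon guard are assumptions):

```python
FP8_MAX = 448.0                                # e4m3 max representable value (assumption)
p_amax = tl.maximum(tl.max(tl.abs(p)), 1e-12)  # per-block absolute max
descale_p = p_amax / FP8_MAX                   # carried alongside the scaled block
p_fp8_scaled = p * (FP8_MAX / p_amax)
```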
Force-pushed from 7d0277c to 4cd9e2a
add fp8 backward