-
-
Notifications
You must be signed in to change notification settings - Fork 10.7k
Bugfix: Cutlass FP8 FusedMoE #27255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Bugfix: Cutlass FP8 FusedMoE #27255
Conversation
When running cutlass FusedMoE FP8 the scaling factors that are passed are None. This PR passes the correct scaling factors and enables the relevant test. Signed-off-by: Amir Klein <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request correctly addresses a bug in the Cutlass FP8 FusedMoE implementation by passing the necessary scaling factors. The changes are logical and enabling the previously skipped test test_flashinfer_cutlass_moe_fp8_no_graph
validates the fix. However, I've identified a critical risk of a division-by-zero error in the calculation of the a2_gscale
factor, which should be addressed to ensure numerical stability.
a1_scale=layer.w13_input_scale, | ||
a1_gscale=layer.w13_input_scale, | ||
a2_scale=layer.w2_input_scale, | ||
a2_gscale=1.0 / layer.w2_input_scale, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The calculation 1.0 / layer.w2_input_scale
introduces a risk of a division-by-zero error if layer.w2_input_scale
is zero. Although scales are typically positive, adding a small epsilon to the denominator is a crucial safeguard for numerical stability.
a2_gscale=1.0 / layer.w2_input_scale, | |
a2_gscale=1.0 / (layer.w2_input_scale + 1e-6), |
a1_scale=td.a1_scale, | ||
a1_gscale=td.a1_scale, | ||
a2_scale=td.a2_scale, | ||
a2_gscale=1.0 / td.a2_scale, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To prevent potential division-by-zero errors and for consistency with the recommended fix in the main logic, it's safer to add a small epsilon to the denominator here. While td.a2_scale
is currently 1.0 in this test, this change improves the robustness of the test suite against future modifications.
a2_gscale=1.0 / td.a2_scale, | |
a2_gscale=1.0 / (td.a2_scale + 1e-6), |
Purpose
When running cutlass FusedMoE FP8 the scaling factors that are passed are None. This PR passes the correct scaling factors and enables the relevant test.
Test Plan
Enabled previously disabled
test_flashinfer_cutlass_moe_fp8_no_graph
.