[Kernel] Enable FP16 and BF16 CUTLASS MoE kernels #15932
base: main
Conversation
Signed-off-by: ElizaWszola <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: ElizaWszola <[email protected]>
General comment: using the name fp16 for operations that handle both fp16 and bf16 is confusing. We should either pick a more general name (16bit?) or, better, append fp8 to the names of ops that handle fp8 and remove fp16 from the names altogether.
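As a purely hypothetical illustration of the suggested naming (the op names and the `torch.ops._C` binding below are assumptions for this sketch, not the ops registered by this PR), dispatch could look like:

```python
import torch


def cutlass_moe_mm_dispatch(a: torch.Tensor, *args, **kwargs):
    # Hypothetical naming per the suggestion above: the unquantized op keeps
    # the plain name and the quantized op carries the "fp8" suffix.
    if a.dtype in (torch.float16, torch.bfloat16):
        return torch.ops._C.cutlass_moe_mm(a, *args, **kwargs)
    if a.dtype == torch.float8_e4m3fn:
        return torch.ops._C.cutlass_moe_mm_fp8(a, *args, **kwargs)
    raise ValueError(f"unsupported activation dtype: {a.dtype}")
```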
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: ElizaWszola <[email protected]>
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    super().process_weights_after_loading(layer)

    # TODO half()
What is the TODO? Resolve before landing?
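If the TODO is about casting the loaded weights to the 16-bit compute dtype, one possible (purely speculative) resolution could look like the sketch below; `w13_weight` and `w2_weight` are assumed attribute names, not necessarily the ones this PR uses.

```python
import torch


def _cast_moe_weights_to_half(layer: torch.nn.Module) -> None:
    """Speculative resolution of "# TODO half()": make sure the MoE weights
    end up in a 16-bit dtype once loading has finished."""
    for name in ("w13_weight", "w2_weight"):  # assumed attribute names
        weight = getattr(layer, name, None)
        if weight is not None and weight.dtype not in (torch.float16,
                                                       torch.bfloat16):
            setattr(layer, name,
                    torch.nn.Parameter(weight.data.half(),
                                       requires_grad=False))
```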
This pull request has merge conflicts that must be resolved before it can be merged.
csrc/cutlass_moe/moe_mm_c3x.cu (outdated)
It looks like the 16-bit configs are the same as the fp8 configs -- these need to be re-tuned for the fp16/bf16 case
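For context, re-tuning generally means timing the kernel across representative (m, n, k) shapes and picking the best tile configuration per M bucket. A generic CUDA-event timing helper along these lines (not vLLM's actual benchmark script) is sketched below.

```python
import torch


def benchmark_ms(fn, warmup: int = 10, iters: int = 100) -> float:
    """Average milliseconds per call of `fn`, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```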
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
def run_8_bit(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
              w2_q: torch.Tensor, w1_scale: torch.Tensor,
              w2_scale: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor, ab_strides1: torch.Tensor,
              c_strides1: torch.Tensor, ab_strides2: torch.Tensor,
              c_strides2: torch.Tensor):
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):
        return cutlass_moe_fp8(a,
                               w1_q,
                               w2_q,
                               w1_scale,
                               w2_scale,
                               topk_weights,
                               topk_ids,
                               ab_strides1,
                               c_strides1,
                               ab_strides2,
                               c_strides2,
                               a1_scale=a_scale)


@pytest.mark.parametrize("m", [2, 64, 224])
@pytest.mark.parametrize("n", [1024, 3072])
@pytest.mark.parametrize("k", [1024, 1536])


        return cutlass_moe(a,
                           w1_q,
                           w2_q,
                           topk_weights,
                           topk_ids,
                           ab_strides1,
                           c_strides1,
                           ab_strides2,
                           c_strides2,
                           w1_scale=w1_scale,
                           w2_scale=w2_scale,
                           a1_scale=a_scale)


def run_16_bit(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
               topk_weights: torch.Tensor, topk_ids: torch.Tensor,
               ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
               ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):
        return cutlass_moe(a, w1, w2, topk_weights, topk_ids, ab_strides1,
                           c_strides1, ab_strides2, c_strides2)
nit: Could these be combined with the scales made optional and defaulted to None for the fp16 case? I don't have strong feelings about this though.
Yeah, probably
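A minimal sketch of such a combined helper, with the scales optional and defaulting to None for the fp16/bf16 case; the import paths below are assumed to mirror the test file and may differ:

```python
from typing import Optional

import torch

# Assumed import locations; adjust to match the actual test file.
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe


def run_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
            ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
            ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
            w1_scale: Optional[torch.Tensor] = None,
            w2_scale: Optional[torch.Tensor] = None,
            a1_scale: Optional[torch.Tensor] = None):
    # One helper for both paths: the scales stay None for fp16/bf16 and are
    # only provided for the quantized (fp8) case.
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):
        return cutlass_moe(a, w1, w2, topk_weights, topk_ids, ab_strides1,
                           c_strides1, ab_strides2, c_strides2,
                           w1_scale=w1_scale, w2_scale=w2_scale,
                           a1_scale=a1_scale)
```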
print(triton_output)
print(cutlass_output)
print("*")
nit: Do we need these prints?
The tolerances in the tests are a bit high, so I use these prints to examine manually how far off the values are when I'm close to the threshold.
print(triton_output)
print(cutlass_output)
print("*")
ditto
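A small helper along these lines (hypothetical, not part of the PR) could summarize how far off the values are without dumping full tensors:

```python
import torch


def report_mismatch(triton_output: torch.Tensor,
                    cutlass_output: torch.Tensor) -> None:
    """Print max absolute and relative differences between the two outputs."""
    diff = (triton_output - cutlass_output).abs()
    denom = cutlass_output.abs().clamp_min(1e-6)
    print(f"max abs diff: {diff.max().item():.6f}, "
          f"max rel diff: {(diff / denom).max().item():.6f}")
```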
if capability_tuple is not None:
    capability = capability_tuple.to_int()
    arch_supported = (capability == required_capability
does this need to be a straight equality check? could it be >=?
These CUTLASS kernels aren't forward-compatible unfortunately
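A tiny illustration of why `==` is used here (the values are made up; `90` stands for SM90/Hopper): because the kernels are compiled for a specific architecture, a `>=` gate would enable them on newer hardware where they cannot run.

```python
required_capability = 90   # SM90 (Hopper) build
capability = 100           # hypothetical newer GPU

assert capability != required_capability  # '==' gate correctly skips
assert capability >= required_capability  # '>=' would wrongly enable
```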
I didn't look over the cutlass parts too closely but the python bits lgtm.
It appears we need to tune these before landing.
[benchmark results attached: main vs. this PR]
using Cutlass3xGemmM16 = typename sm90_16_bit_config_M16<
    InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
using Cutlass3xGemmDefault = typename sm90_16_bit_config_default<
    InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
Hi @ElizaWszola - the 8-bit version uses ScaledEpilogueArray but this uses TrivialEpilogue. Is there a reason why? Thanks!
Ah! I see that it is for dequantization 👍
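For intuition, the epilogue choice roughly corresponds to whether a dequantization scale is applied to the GEMM accumulator; a plain-PyTorch schematic (not the actual CUTLASS epilogues) is:

```python
import torch


def scaled_epilogue(acc: torch.Tensor, a_scale: torch.Tensor,
                    w_scale: torch.Tensor) -> torch.Tensor:
    # fp8 path (ScaledEpilogueArray-like): dequantize the accumulator with
    # the activation scale and the per-expert weight scale.
    return acc * a_scale * w_scale


def trivial_epilogue(acc: torch.Tensor) -> torch.Tensor:
    # fp16/bf16 path (TrivialEpilogue-like): no scales, write out as-is.
    return acc
```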
Implement BF16 and FP16 weight support in CUTLASS MoE kernels. Tested with … and …