Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Enable FP16 and BF16 CUTLASS MoE kernels #15932

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

ElizaWszola
Copy link
Contributor

@ElizaWszola ElizaWszola commented Apr 2, 2025

Implement BF16 and FP16 weight support in CUTLASS MoE kernels. Tested with

llm = LLM("mistralai/Mixtral-8x7B-Instruct-v0.1",
          tensor_parallel_size=2,
)

and

llm = LLM("mistralai/Mixtral-8x7B-Instruct-v0.1",
          tensor_parallel_size=2,
          dtype=torch.float16,
)

Copy link

github-actions bot commented Apr 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@tlrmchlsmth tlrmchlsmth changed the title [WIP][Kernel] Enable BF16 weights in CUTLASS MoE [WIP][Kernel] Enable FP16 and BF16 CUTLASS MoE kernels Apr 2, 2025
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General comment that using the name fp16 for operations that handle both fp16 and bf16 is confusing and we should either pick a more general name (16bit?), or better: append fp8 to the names of ops that handle fp8 and remove fp16 from names altogether

@mergify mergify bot added the ci/build label Apr 3, 2025
@ElizaWszola ElizaWszola marked this pull request as ready for review April 3, 2025 14:21
@ElizaWszola ElizaWszola changed the title [WIP][Kernel] Enable FP16 and BF16 CUTLASS MoE kernels [Kernel] Enable FP16 and BF16 CUTLASS MoE kernels Apr 3, 2025
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)

# TODO half()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the TODO? Resolve before landing?

Copy link

mergify bot commented Apr 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 4, 2025
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the 16-bit configs are the same as the fp8 configs -- these need to be re-tuned for the fp16/bf16 case

Signed-off-by: Tyler Michael Smith <[email protected]>
@mergify mergify bot removed the needs-rebase label Apr 4, 2025
Signed-off-by: Tyler Michael Smith <[email protected]>
Comment on lines +31 to +63
def run_8_bit(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
w2_q: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, ab_strides1: torch.Tensor,
c_strides1: torch.Tensor, ab_strides2: torch.Tensor,
c_strides2: torch.Tensor):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return cutlass_moe_fp8(a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
a1_scale=a_scale)


@pytest.mark.parametrize("m", [2, 64, 224])
@pytest.mark.parametrize("n", [1024, 3072])
@pytest.mark.parametrize("k", [1024, 1536])
return cutlass_moe(a,
w1_q,
w2_q,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale)


def run_16_bit(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return cutlass_moe(a, w1, w2, topk_weights, topk_ids, ab_strides1,
c_strides1, ab_strides2, c_strides2)

Copy link
Contributor

@bnellnm bnellnm Apr 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Could these be combined with the scales made optional and defaulted to None for the fp16 case? I don't have strong feelings about this though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, probably

Comment on lines +313 to +315
print(triton_output)
print(cutlass_output)
print("*")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Do we need this prints?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tolerances in the tests are a bit high, so I use these prints to examine manually how off the values are if I'm close to the treshold

Comment on lines +371 to +373
print(triton_output)
print(cutlass_output)
print("*")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto


if capability_tuple is not None:
capability = capability_tuple.to_int()
arch_supported = (capability == required_capability
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this need to be a straight equality check? could it be >=?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These CUTLASS kernels aren't forward-compatible unfortunately

Copy link
Contributor

@bnellnm bnellnm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't look over the cutlass parts too closely but the python bits lgtm.

@tlrmchlsmth
Copy link
Collaborator

It appears we need to tune these before landing.

vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 --tensor_parallel_size=2 --max_model_len=4096 --port 8192 --disable-log-requests --no-enable-prefix-caching

python benchmarks/benchmark_serving.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 --dataset-name random --random-input-len 1000 --random-output-len 100 --ignore-eos --port 8192 --request-rate 10

main

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  103.86
Total input tokens:                      1000000
Total generated tokens:                  100000
Request throughput (req/s):              9.63
Output token throughput (tok/s):         962.84
Total Token throughput (tok/s):          10591.26
---------------Time to First Token----------------
Mean TTFT (ms):                          81.42
Median TTFT (ms):                        74.82
P99 TTFT (ms):                           151.31
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          25.38
Median TPOT (ms):                        25.33
P99 TPOT (ms):                           29.51
---------------Inter-token Latency----------------
Mean ITL (ms):                           25.38
Median ITL (ms):                         20.68
P99 ITL (ms):                            67.00
==================================================

this pr

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  104.13
Total input tokens:                      1000000
Total generated tokens:                  100000
Request throughput (req/s):              9.60
Output token throughput (tok/s):         960.35
Total Token throughput (tok/s):          10563.86
---------------Time to First Token----------------
Mean TTFT (ms):                          152.00
Median TTFT (ms):                        138.80
P99 TTFT (ms):                           357.98
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          38.58
Median TPOT (ms):                        38.34
P99 TPOT (ms):                           50.13
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.58
Median ITL (ms):                         21.81
P99 ITL (ms):                            161.14
==================================================

using Cutlass3xGemmM16 = typename sm90_16_bit_config_M16<
InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
using Cutlass3xGemmDefault = typename sm90_16_bit_config_default<
InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ElizaWszola - the 8 bit version uses ScaledEpilogueArray but this uses TrivialEpilogue . Is there a reason why ? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ! I see that it is for dequantization 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants