
Conversation


@xiangze-arm commented Oct 21, 2025

Description

  • Add oneDNN/ACL matmul path for AArch64 in fused MoE
  • Use the silu_and_mul op

Test Plan

  • Tested locally with the Qwen3-30B-A3B model.
  • Unit test: pytest tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic

Test Result

  • End-to-end test result is OK.
  • Unit test passes.

Performance

With this PR, MoE takes the onednn_mm path on AArch64 CPUs. On 32 Neoverse-N2 cores, this PR achieves about 1.6x the throughput of the current default path.

Bench command:

vllm bench throughput --dtype=bfloat16 --num-prompts 64 --seed 0 --dataset-name sharegpt --max-model-len=2048 --input_len=64 --model Qwen3-30B-A3B

- Add oneDNN/ACL matmul path for AArch64
- Use silu_and_mul Op

Signed-off-by: Zhang Xiangze <[email protected]>
@xiangze-arm requested a review from mgoin as a code owner on October 21, 2025 03:47
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request improves CPU performance for fused MoE layers by introducing a oneDNN/ACL matmul path for AArch64 architectures and leveraging a custom silu_and_mul operator. The changes are well-implemented, preparing oneDNN matmul handlers during initialization and using them in the forward pass. While this is a solid performance enhancement, I've identified a high-severity issue concerning model serialization. The created oneDNN handlers are not serializable, and attaching them to the model can lead to crashes upon saving and reloading the model.

Comment on lines +256 to +267
gate_up_handle = ops.create_onednn_mm(layer_w13_weight.t(), 32)
layer.gate_up_linear.append(
    lambda x, handle=gate_up_handle, bias=layer_w13_bias: ops.onednn_mm(
        handle, x, bias
    )
)
down_handle = ops.create_onednn_mm(layer_w2_weight.t(), 32)
layer.down_linear.append(
    lambda x, handle=down_handle, bias=layer_w2_bias: ops.onednn_mm(
        handle, x, bias
    )
)
@gemini-code-assist bot (Contributor) commented
Severity: high

The CPUDNNLGEMMHandler objects created by ops.create_onednn_mm contain pointers to C++ state and are not serializable. Storing lambdas that capture these handlers in layer.gate_up_linear and layer.down_linear will cause issues if the model is serialized (e.g., with pickle or torch.save). Upon deserialization, the handler pointers will be invalid, which can lead to segmentation faults when the model is used or garbage collected.

To prevent this, the CPUDNNLGEMMHandler class should be made non-picklable by implementing __getstate__ to raise an exception. Since that class is not in this file, an alternative is to avoid storing these handlers on the torch.nn.Module instance if model serialization is a possibility. If serialization is not a supported use case for CPU-based models, this might be acceptable, but it's a significant risk.
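
A minimal sketch of the suggested guard, assuming a Python-level wrapper class (the real CPUDNNLGEMMHandler may live in another module or in the C++ bindings):

class CPUDNNLGEMMHandler:
    """Sketch only: wraps an opaque oneDNN matmul primitive created natively."""

    def __getstate__(self):
        # Refuse serialization so a stale native pointer can never be revived
        # by pickle/torch.save and later cause a segfault.
        raise TypeError(
            "CPUDNNLGEMMHandler holds native oneDNN state and cannot be "
            "pickled; recreate it via ops.create_onednn_mm() after loading."
        )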

@chatgpt-codex-connector bot left a comment
💡 Codex Review

Here are some automated review suggestions for this pull request.


Signed-off-by: Zhang Xiangze <[email protected]>
@xiangze-arm (Author) commented

I saw a vLLM crash when running the unit test locally with pytest tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic.

The issue does not seem to be directly related to this PR, but the MoE test exposed a problem in the oneDNN MatMul cache.

The issue: ClassMatmulCacheKey in MatMulPrimitiveHandler does not take the data type into account, so two onednn_mm calls with the same weight shape but different data types wrongly share the same cache key. The MoE test runs bf16 and fp32 with the same weight shape, which causes the crash.

The tests pass when I run bf16 and fp32 separately.
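
To illustrate the collision in plain Python (the real cache key lives in the C++ MatMulPrimitiveHandler; the names below are for illustration only):

import torch

def cache_key(weight: torch.Tensor, include_dtype: bool):
    # Keying on shape alone lets bf16 and fp32 weights of identical shape
    # collide; adding the dtype keeps their primitives separate.
    shape = tuple(weight.shape)
    return (shape, weight.dtype) if include_dtype else shape

w_bf16 = torch.zeros(128, 64, dtype=torch.bfloat16)
w_fp32 = torch.zeros(128, 64, dtype=torch.float32)

assert cache_key(w_bf16, include_dtype=False) == cache_key(w_fp32, include_dtype=False)  # collision
assert cache_key(w_bf16, include_dtype=True) != cache_key(w_fp32, include_dtype=True)    # distinct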

@xiangze-arm (Author) commented

@mgoin @bigPYJ1151 Could you help review this PR?

cc @fadara01 @nikhil-arm

@fadara01 (Contributor) commented

Hi @xiangze-arm - can you test again with #27472?

class CPUFusedMOE:
    def __init__(self, layer: torch.nn.Module) -> None:
        pass
        use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()

Could you please add the performance implications of this PR to the description?

@xiangze-arm (Author) commented

I have tested this PR with the #27472 fix; the unit test passes without crashing.

The PR description is also updated with the performance result (1.6x throughput improvement).

-    d = x.shape[-1] // 2
-    return F.silu(x[..., :d]) * x[..., d:]
+    output = torch.empty_like(x[..., : x.shape[-1] // 2])
+    torch.ops._C.silu_and_mul(output, x)

I left a similar comment on your 4-bit PR: what happens if this assert fails?

TORCH_CHECK(d % VEC_ELEM_NUM == 0);

@fadara01 (Contributor) commented Oct 30, 2025

I guess we crash because the activation kernel does not have a tail loop. We should address this, but I don't think this particular concern should block this PR.
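
A possible Python-side guard, sketched here with an assumed VEC_ELEM_NUM value (the real constant lives in the C++ kernel), would fall back to the eager implementation whenever the half-width is not vector-aligned:

import torch
import torch.nn.functional as F

VEC_ELEM_NUM = 8  # assumption for illustration only

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    if d % VEC_ELEM_NUM == 0:
        # Vectorized custom op; requires a vLLM build with the _C extension.
        out = torch.empty_like(x[..., :d])
        torch.ops._C.silu_and_mul(out, x)
        return out
    # Tail-safe fallback for sizes the vectorized kernel cannot handle.
    return F.silu(x[..., :d]) * x[..., d:]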

@fadara01 (Contributor) commented

@bigPYJ1151 could you have a look at this please?

-    d = x.shape[-1] // 2
-    return F.silu(x[..., :d]) * x[..., d:]
+    output = torch.empty_like(x[..., : x.shape[-1] // 2])
+    torch.ops._C.silu_and_mul(output, x)
@fadara01 (Contributor) commented Oct 30, 2025

Why is this change needed here? Is it faster than doing PyTorch's silu followed by a PyTorch mul?
I'm asking because exp() (needed to implement silu) is not really vectorized for Arm in vLLM (we just call std::exp on each element of the vector).
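
One way to answer that empirically is a small microbenchmark on the target CPU; this sketch assumes a vLLM CPU build whose compiled _C extension registers the custom op (e.g. by importing vllm first):

import time
import torch
import torch.nn.functional as F
import vllm  # assumption: importing vllm registers torch.ops._C in a CPU build

x = torch.randn(64, 2 * 2048, dtype=torch.bfloat16)

def eager(x):
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def custom(x):
    out = torch.empty_like(x[..., : x.shape[-1] // 2])
    torch.ops._C.silu_and_mul(out, x)
    return out

for name, fn in (("eager silu*mul", eager), ("_C.silu_and_mul", custom)):
    fn(x)  # warm-up
    t0 = time.perf_counter()
    for _ in range(1000):
        fn(x)
    print(f"{name}: {(time.perf_counter() - t0) * 1e3:.1f} ms / 1000 iters")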
