feat(ep): add AccumNum=9 vec8 intranode combine for shared-expert fusion#396

Open
rbrugaro-amd wants to merge 1 commit into
ROCm:mainfrom
rbrugaro-amd:rbrugaro/ep-combine-vec8-top9-shared-fusion

@rbrugaro-amd

Summary

  • Adds an AccumNum=9 specialization of the fast weightless fp8_blockwise vec8 intranode combine kernel so shared-expert fusion (topk 8→9) can use the fast path instead of falling back to the generic combine.
  • Generalizes the existing AccumNum=8 vec8 device helpers + the EpCombineIntraNodeKernel template on a new AccumNum template parameter (default 8) — no behavior change for existing top-8 callers.
  • Relaxes the launch gate (C++ launch.cpp and the mirrored Python path in dispatch_combine.py) from a hardcoded numExpertPerToken == 8 to {8, 9}, selecting the _vec8_top9 symbol when topk == 9.
  • Adds a weightless intranode combine test covering topk=8 and topk=9 that asserts the specialized kernel was actually selected (guards against silent fallback).
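The relaxed gate described above can be sketched roughly as follows. This is a minimal illustration, not the code in dispatch_combine.py: the function name, its parameters, and the bare `_vec8` suffix for the top-8 case are assumptions made for the example. Only the accepted set {8, 9} and the `_vec8_top9` suffix come from this PR description.

```python
from typing import Optional

# Hypothetical constant: the PR relaxes the gate from == 8 to membership in {8, 9}.
VEC8_SUPPORTED_TOPK = (8, 9)


def select_vec8_suffix(
    num_experts_per_token: int, weightless: bool, quant: str
) -> Optional[str]:
    """Return a vec8 symbol suffix, or None to fall back to the generic combine.

    Hypothetical helper; the "_vec8" suffix for top-8 is an assumption.
    """
    # The fast path only applies to the weightless fp8_blockwise combine.
    if not weightless or quant != "fp8_blockwise":
        return None
    # Before this PR the check was effectively `num_experts_per_token == 8`.
    if num_experts_per_token not in VEC8_SUPPORTED_TOPK:
        return None
    return "_vec8_top9" if num_experts_per_token == 9 else "_vec8"
```

Under these assumptions, topk=9 now resolves to the `_vec8_top9` suffix, while any other unsupported top-k still returns None and takes the generic path.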

Why

Shared-expert fusion folds the always-on shared expert in as an extra grouped-GEMM expert and raises the effective top-k from 8 to 9 (8 routed + 1 fused shared). The fast weightless fp8_blockwise vec8 combine path only had an AccumNum=8 specialization that unrolls over exactly 8 accumulation sources, so numExpertPerToken == 9 failed the gate by construction and dropped to the slower generic combine kernel.

The combine call is already weightless in this path (the caller passes weights=None), so the only blocker was the hardcoded top-k. The fix is a new AccumNum=9 kernel instantiation plus the relaxed gate. The accumulation math is identical to the existing weightless path, just unrolled over 9 sources instead of 8.
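To make "identical math, different unroll count" concrete, here is a plain-Python reference model of a weightless blockwise-dequantized combine. It is a sketch under assumptions: the function name and argument layout are hypothetical, values are ordinary floats rather than fp8, and none of the warp-level or vectorized structure of the real kernel is modeled. The only point is that the number of sources is a loop bound and does not change the per-element arithmetic.

```python
from typing import List, Sequence


def combine_weightless(
    sources: Sequence[Sequence[float]],
    scales: Sequence[Sequence[float]],
    block_elems: int = 128,
) -> List[float]:
    """Sum dequantized values over all sources (hypothetical reference model).

    sources[s][i] is the quantized value of element i from source s.
    scales[s][b] is the dequantization scale of block b for source s.
    """
    hidden = len(sources[0])
    out = [0.0] * hidden
    # The number of sources (8 or 9) only sets how many times this loop runs.
    for src, src_scales in zip(sources, scales):
        for i in range(hidden):
            # Dequantize with the scale of the block that element i belongs to.
            out[i] += src[i] * src_scales[i // block_elems]
    return out
```

Calling this with 8 or with 9 source rows runs the same per-element sum; only the outer loop count differs, which is what the AccumNum template parameter fixes at compile time.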

What changed

| File | Change |
| --- | --- |
| include/mori/core/transport/p2p/device_primitives.hpp | WarpAccumFp8DequantFullBlockVec8Top8 / ...SegmentBlockVec8Top8 / ...SegmentScalarTop8 gain an int AccumNum = 8 template param (was constexpr int AccumNum = 8). |
| src/ops/dispatch_combine/intranode.hpp | EpCombineIntraNodeKernel body + global gain int Vec8AccumNum = 8, threaded into the three vec8 helper calls. |
| src/ops/kernels/ep_common.hip | New WRAP_BOOL7 wrapper macro (7 non-type template args). |
| src/ops/kernels/ep_intranode.hip | Register ..._noweight_block128_vec8_top9 and ..._block256_vec8_top9 (AccumNum=9) symbols. |
| src/ops/dispatch_combine/launch.cpp | Gate accepts numExpertPerToken in {8,9}; selects _top9 symbol when topk==9. |
| python/mori/ops/dispatch_combine.py | Mirror of the launch gate; records _last_combine_kernel_name for test introspection. |
| tests/python/ops/dispatch_combine_test_utils.py | run_test_once / run_ep_dispatch_combine_test gain weightless + expect_combine_kernel_substr; asserts the selected kernel. |
| tests/python/ops/test_dispatch_combine_intranode.py | New test_dispatch_combine_weightless_vec8 (topk 8 and 9). |

Trace verification

Before:
image

After:
image

Testing

Validated in a fresh container on MI355X (gfx950) against current main (post-#392, the PR that rewrote the combine body) — not just the fork base — to confirm the patch still applies and behaves correctly after the recent combine refactor.

  • Build: C++ extension + JIT ep_intranode.hsaco rebuild clean; both noweight_block128_vec8_top9 and noweight_block256_vec8_top9 symbols present in the compiled HSACO.
  • New test (test_dispatch_combine_weightless_vec8): 2 passed (topk 8 + topk 9). The topk=9 case asserts the runtime selected noweight_block128_vec8_top9, so a silent fallback to the generic path (which would produce identical weightless numbers) fails the test.
  • Top-8 regression (test_dispatch_combine): 128 passed, 0 failed (16 fp8_blockwise + 112 none/fp8_direct_cast; remaining parametrizations skip on existing constraints).
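The kernel-selection assertion matters because the generic and specialized paths produce the same weightless numbers, so a numeric check alone cannot tell them apart. The pattern can be sketched as below. `FakeCombineOp`, its kernel-name strings, and `assert_kernel_selected` are hypothetical stand-ins for illustration; only the `_last_combine_kernel_name` attribute and the `expect_combine_kernel_substr` parameter name are taken from this PR description.

```python
class FakeCombineOp:
    """Hypothetical stand-in for the real op; records which kernel it picked."""

    def __init__(self, topk, vec8_supported=(8, 9)):
        self.topk = topk
        self._vec8_supported = vec8_supported
        self._last_combine_kernel_name = None

    def combine(self):
        # Illustrative kernel names, not the real symbol names.
        if self.topk in self._vec8_supported:
            suffix = "_vec8_top9" if self.topk == 9 else "_vec8"
            self._last_combine_kernel_name = "combine_noweight_block128" + suffix
        else:
            self._last_combine_kernel_name = "combine_generic"


def assert_kernel_selected(op, expect_combine_kernel_substr):
    """Fail if the recorded kernel name does not contain the expected substring."""
    name = op._last_combine_kernel_name
    assert name is not None and expect_combine_kernel_substr in name, (
        f"expected kernel containing {expect_combine_kernel_substr!r}, got {name!r}"
    )
```

With a gate that only accepts top-8 (`vec8_supported=(8,)`), a topk=9 run records the generic name and `assert_kernel_selected(op, "vec8_top9")` raises, which is the silent-fallback case the new test is meant to catch.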

CI note: the intranode-test job already runs pytest tests/python/ops/test_dispatch_combine_intranode.py -v (timeout 360s). The new test adds ~21s.

End-to-end impact (DeepSeek-R1 MXFP4, MI355X, EP=8 offline)

Measured in the MLPerf offline harness under shared-expert fusion (topk=9), same RNG seeds and test parameters in both runs (apples-to-apples). The runtime confirmed selecting EpCombineIntraNodeKernel_bf16_nop2p_fp8bwq_noweight_block128_vec8_top9 (topk=9, block_elems=128) on all 8 ranks. Both results VALID.

| Combine path | Tokens/s | Samples/s |
| --- | --- | --- |
| generic (before) | 31579.1 | 8.37529 |
| vec8_top9 (this PR) | 32166.2 | 8.45422 |
| Δ | +587.1 (+1.86%) | +0.94% |

Accuracy is unchanged (accumulation math is identical to the generic weightless path).
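For reference, the relative deltas in the table above follow directly from the before/after figures. The helper name below is just for illustration:

```python
def relative_gain(before: float, after: float) -> float:
    """Percentage change from `before` to `after`."""
    return (after - before) / before * 100.0


# Figures copied from the results table (generic vs. vec8_top9).
tokens_gain = relative_gain(31579.1, 32166.2)
samples_gain = relative_gain(8.37529, 8.45422)

print(f"tokens/s:  {tokens_gain:+.2f}%")   # +1.86%
print(f"samples/s: {samples_gain:+.2f}%")  # +0.94%
```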


Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>

Shared-expert fusion routes every token through 9 sources (8 routed + 1
fused shared expert, topk 8->9), but the fast weightless fp8_blockwise
vec8 combine path only had an AccumNum=8 specialization, so topk=9 fell
back to the generic combine kernel.

Generalize the WarpAccumFp8Dequant* vec8 helpers and the
EpCombineIntraNodeKernel template on an AccumNum parameter (default 8),
add a WRAP_BOOL7 wrapper macro, and register block128/block256
_vec8_top9 (AccumNum=9) symbols. The launch gate (launch.cpp and the
mirrored Python path in dispatch_combine.py) now accepts
numExpertPerToken in {8,9} and selects the _top9 symbol when topk==9.

Add a weightless intranode combine test that exercises both topk=8 and
topk=9 fp8_blockwise paths and asserts the specialized kernel was
actually selected (guards against silent fallback to the generic path,
which would otherwise produce identical weightless results).

Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>

Copilot AI left a comment

Pull request overview

This PR extends the fast weightless fp8_blockwise vec8 intranode combine path to support top-k = 9 (AccumNum=9) for shared-expert fusion, by adding new kernel instantiations and relaxing the kernel-selection gate in both the C++ and Python dispatch paths. It also adds a Python test that asserts the specialized kernel was selected to guard against silent fallback to the generic combine.

Changes:

  • Add AccumNum=9 (_vec8_top9) kernel symbols and generalize vec8 accumulation helpers to accept an AccumNum template parameter.
  • Relax C++/Python launch gating to select the vec8 specialization for numExpertPerToken in {8, 9} and choose _top9 when topk == 9.
  • Add a weightless intranode combine test that validates the specialized kernel was actually selected.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| include/mori/core/transport/p2p/device_primitives.hpp | Generalizes vec8 fp8 dequant+accum helpers to allow AccumNum=9. |
| src/ops/dispatch_combine/intranode.hpp | Threads a new Vec8AccumNum template parameter into the vec8 helper calls and kernel template. |
| src/ops/kernels/ep_common.hip | Adds WRAP_BOOL7 wrapper macro for kernels with 7 non-type template args. |
| src/ops/kernels/ep_intranode.hip | Registers new top9 (AccumNum=9) weightless vec8 combine symbols. |
| src/ops/dispatch_combine/launch.cpp | Relaxes the C++ vec8 specialization gate to allow top-k 9 and selects _top9 symbols. |
| python/mori/ops/dispatch_combine.py | Mirrors the updated gate and records _last_combine_kernel_name for tests. |
| tests/python/ops/dispatch_combine_test_utils.py | Adds a weightless test option and kernel-name assertion hook. |
| tests/python/ops/test_dispatch_combine_intranode.py | Adds a new test covering weightless vec8 combine for topk 8 and 9. |

Comment on lines +811 to +813
test_case.run_test_once(
    op, test_data, check_results=check_results, weightless=weightless
)