feat(ep): add AccumNum=9 vec8 intranode combine for shared-expert fusion #396
Open
rbrugaro-amd wants to merge 1 commit into
Shared-expert fusion routes every token through 9 sources (8 routed + 1
fused shared expert, topk 8->9), but the fast weightless fp8_blockwise
vec8 combine path only had an AccumNum=8 specialization, so topk=9 fell
back to the generic combine kernel.
Generalize the WarpAccumFp8Dequant* vec8 helpers and the
EpCombineIntraNodeKernel template on an AccumNum parameter (default 8),
add a WRAP_BOOL7 wrapper macro, and register block128/block256
_vec8_top9 (AccumNum=9) symbols. The launch gate (launch.cpp and the
mirrored Python path in dispatch_combine.py) now accepts
numExpertPerToken in {8,9} and selects the _top9 symbol when topk==9.
Add a weightless intranode combine test that exercises both topk=8 and
topk=9 fp8_blockwise paths and asserts the specialized kernel was
actually selected (guards against silent fallback to the generic path,
which would otherwise produce identical weightless results).
Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
Pull request overview
This PR extends the fast weightless fp8_blockwise vec8 intranode combine path to support top-k = 9 (AccumNum=9) for shared-expert fusion, by adding new kernel instantiations and relaxing the kernel-selection gate in both the C++ and Python dispatch paths. It also adds a Python test that asserts the specialized kernel was selected to guard against silent fallback to the generic combine.
Changes:
- Add AccumNum=9 (`_vec8_top9`) kernel symbols and generalize vec8 accumulation helpers to accept an `AccumNum` template parameter.
- Relax C++/Python launch gating to select the vec8 specialization for `numExpertPerToken in {8, 9}` and choose `_top9` when `topk == 9`.
- Add a weightless intranode combine test that validates the specialized kernel was actually selected.
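
The relaxed gate in the list above can be pictured with a small Python sketch. It is illustrative only: `vec8_symbol_suffix` and `VEC8_SUPPORTED_TOPK` are hypothetical names, the top-8 suffix format is assumed by analogy with the `_top9` symbols, and the real gate in `launch.cpp` / `dispatch_combine.py` checks further conditions (quantization type, absence of weights) that are not modelled here.

```python
from typing import Optional

# Hypothetical constant: the top-k values the vec8 fast path accepts after this PR.
VEC8_SUPPORTED_TOPK = (8, 9)


def vec8_symbol_suffix(num_experts_per_token: int, block_elems: int) -> Optional[str]:
    """Return an (assumed) vec8 symbol suffix, or None to signal generic fallback."""
    if num_experts_per_token not in VEC8_SUPPORTED_TOPK:
        return None  # any other top-k uses the generic combine kernel
    if block_elems not in (128, 256):
        return None  # only block128 / block256 symbols are registered
    suffix = f"block{block_elems}_vec8"
    # topk == 9 selects the AccumNum=9 instantiation.
    if num_experts_per_token == 9:
        suffix += "_top9"
    return suffix
```

With this sketch, `vec8_symbol_suffix(9, 128)` yields `"block128_vec8_top9"`, while an unsupported top-k such as 7 returns `None`.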
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `include/mori/core/transport/p2p/device_primitives.hpp` | Generalizes vec8 fp8 dequant+accum helpers to allow AccumNum=9. |
| `src/ops/dispatch_combine/intranode.hpp` | Threads a new `Vec8AccumNum` template parameter into the vec8 helper calls and kernel template. |
| `src/ops/kernels/ep_common.hip` | Adds `WRAP_BOOL7` wrapper macro for kernels with 7 non-type template args. |
| `src/ops/kernels/ep_intranode.hip` | Registers new top9 (AccumNum=9) weightless vec8 combine symbols. |
| `src/ops/dispatch_combine/launch.cpp` | Relaxes the C++ vec8 specialization gate to allow top-k 9 and selects `_top9` symbols. |
| `python/mori/ops/dispatch_combine.py` | Mirrors the updated gate and records `_last_combine_kernel_name` for tests. |
| `tests/python/ops/dispatch_combine_test_utils.py` | Adds a weightless test option and kernel-name assertion hook. |
| `tests/python/ops/test_dispatch_combine_intranode.py` | Adds a new test covering weightless vec8 combine for topk 8 and 9. |
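
The kernel-name assertion hook mentioned in the table could look roughly like the sketch below. The attribute `_last_combine_kernel_name` and the parameter name `expect_combine_kernel_substr` come from the PR; `FakeOp` and `check_combine_kernel` are hypothetical stand-ins written for illustration, not the code in `dispatch_combine_test_utils.py`.

```python
from typing import Optional


class FakeOp:
    """Hypothetical stand-in for the real op; models only the recorded kernel name."""

    def __init__(self, kernel_name: str) -> None:
        self._last_combine_kernel_name = kernel_name


def check_combine_kernel(op, expect_combine_kernel_substr: Optional[str]) -> None:
    """Fail if the recorded combine kernel name lacks the expected substring."""
    if expect_combine_kernel_substr is None:
        return  # caller did not ask for a kernel-selection check
    name = getattr(op, "_last_combine_kernel_name", None)
    assert name is not None, "no combine kernel name was recorded"
    assert expect_combine_kernel_substr in name, (
        f"expected a kernel containing {expect_combine_kernel_substr!r}, got {name!r}"
    )
```

Because the weightless generic path produces the same numbers, a check of this kind on the kernel name is what turns a silent fallback into a test failure.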
Comment on lines +811 to +813

    test_case.run_test_once(
        op, test_data, check_results=check_results, weightless=weightless
    )
feat(ep): add AccumNum=9 vec8 intranode combine for shared-expert fusion
Summary
- Add an AccumNum=9 specialization of the fast weightless `fp8_blockwise` vec8 intranode combine kernel so shared-expert fusion (topk 8→9) can use the fast path instead of falling back to the generic combine.
- Generalize the `WarpAccumFp8Dequant*` vec8 helpers and the `EpCombineIntraNodeKernel` template on a new `AccumNum` template parameter (default 8); no behavior change for existing top-8 callers.
- Relax the launch gate (`launch.cpp` and the mirrored Python path in `dispatch_combine.py`) from a hardcoded `numExpertPerToken == 8` to `{8, 9}`, selecting the `_vec8_top9` symbol when topk == 9.

Why
Shared-expert fusion folds the always-on shared expert in as an extra grouped-GEMM expert and raises the effective top-k from 8 to 9 (8 routed + 1 fused shared). The fast weightless `fp8_blockwise` vec8 combine path only had an AccumNum=8 specialization that unrolls over exactly 8 accumulation sources, so `numExpertPerToken == 9` failed the gate by construction and dropped to the slower generic combine kernel.

The combine call is already weightless in this path (the caller passes `weights=None`), so the only blocker was the hardcoded top-k. The fix is a new AccumNum=9 kernel instantiation plus the relaxed gate; the accumulation math is identical to the existing weightless path, just unrolled over 9 sources instead of 8.

What changed
| File | Change |
|---|---|
| `include/mori/core/transport/p2p/device_primitives.hpp` | `WarpAccumFp8DequantFullBlockVec8Top8` / `...SegmentBlockVec8Top8` / `...SegmentScalarTop8` gain an `int AccumNum = 8` template param (was `constexpr int AccumNum = 8`). |
| `src/ops/dispatch_combine/intranode.hpp` | `EpCombineIntraNodeKernel` body + global gain `int Vec8AccumNum = 8`, threaded into the three vec8 helper calls. |
| `src/ops/kernels/ep_common.hip` | `WRAP_BOOL7` wrapper macro (7 non-type template args). |
| `src/ops/kernels/ep_intranode.hip` | `..._noweight_block128_vec8_top9` and `..._block256_vec8_top9` (AccumNum=9) symbols. |
| `src/ops/dispatch_combine/launch.cpp` | `numExpertPerToken in {8,9}`; selects `_top9` symbol when topk==9. |
| `python/mori/ops/dispatch_combine.py` | `_last_combine_kernel_name` for test introspection. |
| `tests/python/ops/dispatch_combine_test_utils.py` | `run_test_once` / `run_ep_dispatch_combine_test` gain `weightless` + `expect_combine_kernel_substr`; asserts the selected kernel. |
| `tests/python/ops/test_dispatch_combine_intranode.py` | `test_dispatch_combine_weightless_vec8` (topk 8 and 9). |

Trace verification
Before:

After:

Testing
Validated in a fresh container on MI355X (gfx950) against current `main` (post-#392, the PR that rewrote the combine body), not just the fork base, to confirm the patch still applies and behaves correctly after the recent combine refactor.

- `ep_intranode.hsaco` rebuild clean; both `noweight_block128_vec8_top9` and `noweight_block256_vec8_top9` symbols present in the compiled HSACO.
- `test_dispatch_combine_weightless_vec8`: 2 passed (topk 8 + topk 9). The topk=9 case asserts that the runtime selected `noweight_block128_vec8_top9`, so a silent fallback to the generic path (which would produce identical weightless numbers) fails the test.
- `test_dispatch_combine`: 128 passed, 0 failed (16 fp8_blockwise + 112 none/fp8_direct_cast; remaining parametrizations skip on existing constraints).

CI note: the `intranode-test` job already runs `pytest tests/python/ops/test_dispatch_combine_intranode.py -v` (timeout 360s). The new test adds ~21s.

End-to-end impact (DeepSeek-R1 MXFP4, MI355X, EP=8 offline)
Measured in the MLPerf offline harness under shared-expert fusion (topk=9), with the same RNG seeds and test parameters in both runs (apples-to-apples). The runtime confirmed selecting `EpCombineIntraNodeKernel_bf16_nop2p_fp8bwq_noweight_block128_vec8_top9` (topk=9, block_elems=128) on all 8 ranks. Both results VALID.

[Results table not recovered; surviving row label: `vec8_top9` (this PR)]

Accuracy is unchanged (accumulation math is identical to the generic weightless path).
Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
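
As a rough illustration of why accuracy is unchanged, the weightless blockwise accumulation can be written as a plain sum over sources. This is a host-side sketch under stated assumptions, not the HIP kernel: `combine_weightless` is a hypothetical name, the inputs are fp8 values already decoded to Python floats, and each source carries one scale per block of `block_elems` elements.

```python
from typing import List, Sequence


def combine_weightless(
    sources: Sequence[Sequence[float]],
    scales: Sequence[Sequence[float]],
    block_elems: int,
) -> List[float]:
    """Dequantize each source with its per-block scale and sum across sources."""
    hidden = len(sources[0])
    out = [0.0] * hidden
    # The number of sources plays the role of AccumNum: 8 or 9 iterations,
    # with the same per-element arithmetic in both cases.
    for src, src_scales in zip(sources, scales):
        for i in range(hidden):
            out[i] += src[i] * src_scales[i // block_elems]
    return out
```

Going from 8 to 9 sources only adds one more term to each per-element sum; no weights enter the computation, so a specialized and a generic kernel following this scheme compute the same quantity.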