Motivation
nvFuser currently achieves multi-GPU overlap by scheduling separate communication and compute kernels through Host IR, using stream parallelism. This works, but the overlap granularity is coarse, and communications necessarily form kernel fusion boundaries.
This PR explores a different approach: GPU-initiated communication inside compute kernels. Instead of the host orchestrating separate comm and compute phases, a single CUDA kernel reads/writes remote GPU memory directly via symmetric memory pointers, interleaving data movement and computation at the thread level.
This is an experimental reference PR -- not intended for merge as-is, but as a self-contained, readable codebase for the team to study, benchmark, reproduce, and iterate on. The fused scalar kernels demonstrate the comm patterns and synchronization model. The two-kernel CUTLASS variants establish a performance ceiling. Closing the gap -- achieving CUTLASS-level compute inside a truly fused single kernel -- is the central open problem where we need the compute team's expertise.
For simplicity, we focus on the "Allgather+Matmul" problem on a single H100 node over NVLink.
How to run
Requires Hopper (SM90) for the CUTLASS and multimem variants. Build with a flag like TORCH_CUDA_ARCH_LIST="9.0a".
What this PR contains
A self-contained benchmark comparing 7 distributed matmul implementations for C[M,N] = A[M,K] x B[K,N], where A is row-sharded across ranks on axis M and B is replicated. All code lives in 3 test files:
test_multidevice_fused_remote_matmul.h: Shared types, enum, context struct, perf summary
test_multidevice_fused_remote_matmul_kernel.cu: CUDA kernels, CUTLASS wrapper, launchers
test_multidevice_fused_remote_matmul.cpp: Test harness, resource setup, timing, baselines
Small infrastructure change: SymmetricTensor::devicePeerPointers() added to symmetric_tensor.{h,cpp} -- lazily allocates a device-side pointer table for convenient kernel access to peer buffers.
Implementations
Baselines (separate allgather + matmul, no fusion):
baselineNcclAllgatherMatmul -- NCCL allgather to rebuild full A, then at::matmul. The standard-library reference.
baselineCudaAllgatherMatmul -- Same pattern, but using nvFuser's native backend for the allgather with multicast NVLS.
Truly fused kernels (comm + compute in a single kernel launch):
naiveRemoteRead -- Simplest possible fusion. Each thread computes one C[row,col] by reading A elements directly from the owner rank's shard via remote pointers. No staging, no explicit gather. Every A element traverses NVLink on every access -- no reuse.
threadloadGatherScalarCompute -- Two-stage fused kernel. Stage 1: cooperative thread loads gather one full A row from the owner's remote shard into a local staging buffer. Stage 2: scalar matmul from the staged row. Inter-rank synchronization via device-side ready/done semaphores (owner signals readiness; non-owners poll; readers ack completion). See threadloadGatherKernel (tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu, line 292).
multimemGatherScalarCompute -- Same two-stage structure, but Stage 1 uses Hopper multimem.st.global.v4.f32 instructions to write A rows to an NVLS multicast buffer, delivering data to all peers in hardware. Requires SM90+ and multicast-capable symmetric memory. See multimemGatherKernel (tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu, line 35).
Two-kernel path (separate comm kernel, then CUTLASS GEMM -- NOT truly fused):
threadloadGatherThenCutlass -- The threadload gather kernel (with semaphores) materializes full A into a staging buffer, then a separate host-launched CUTLASS 3.x SM90 TMA GEMM consumes that buffer. These are two distinct <<<...>>> launches on the same stream. The gather kernel runs with n=0 to skip its in-kernel compute stage.
multimemGatherThenCutlass -- Same as above but using multimem gather instead of threadload gather.
These two-kernel variants establish a performance ceiling: they show what throughput is achievable when the comm pattern is correct and the compute is Hopper-native WGMMA. True single-kernel fusion with equivalent compute quality is the first goal -- see "Where I need the team's input" below.
Performance (8xH100 DGX, M=N=K=1024, half precision)
Key observations:
The two-kernel variants outperform the NCCL baseline by ~20%, achieving 50 TFLOP/s. This validates that our infrastructure works and can beat standard collectives.
The truly fused scalar kernels are 12-15x slower -- the compute is the bottleneck, not the communication. The comm patterns and synchronization model are sound; what's missing is Hopper-native compute inside the fused kernel.
The gap between fused-scalar (4 TFLOP/s) and two-kernel-CUTLASS (50 TFLOP/s) is the central problem. Closing it requires embedding WGMMA and TMA pipelines inside the comm kernel (possibly through cutlass device API).
Synchronization model
Fused kernels require device-side inter-rank synchronization since there is no host between the comm and compute stages. This PR implements epoch-based remote semaphores:
Semaphore buffers are allocated in symmetric memory (visible to all ranks)
The owner rank publishes a monotonically increasing epoch to all peers' semaphore copies via __threadfence_system() + remote writes
Reader ranks poll their local semaphore copy with atomicAdd(..., 0) until the expected epoch appears
After computation, readers publish a "done" epoch back to the owner, who waits before the next row can be reused
Polling is owner-scoped: readers only wait on the specific owner rank, not all ranks.
Where I need the team's input
The central challenge: true single-kernel fusion with Hopper-native compute.
The fused scalar kernels prove that the comm and sync model works. The two-kernel CUTLASS path proves the perf ceiling is high. But achieving both in a single kernel is challenging for me because:
CUTLASS 3.x mainloops are designed to own the entire kernel -- they manage shared memory layout, warpgroup roles (TMA producer vs MMA consumer), and async pipeline barriers. They cannot be called from within another kernel.
TMA descriptors are created on the host via cuTensorMapEncodeTiled. They cannot be created from device code.
WGMMA requires careful warpgroup scheduling -- which warps do MMA, which do data movement, and how shared memory is partitioned between operand staging and communication buffers.
The right approach is likely a custom kernel using CUTE primitives (MMA_Atom, TiledMMA, TMA_LOAD) at the building-block level, weaving P2P comm into the producer/consumer pipeline. This is where I need your expertise:
How would you partition warps between P2P communication and WGMMA compute in a single kernel?
Can TMA descriptors address remote symmetric memory? If so, the gather stage could use cp.async.bulk instead of thread loads, freeing SMs entirely.
What shared memory layout works for staging both the comm buffer (incoming A tiles from remote ranks) and the MMA operand buffers?
How should we pipeline -- gather chunk K of A from remote while computing on chunk K-1? What tile sizes and pipeline depths make sense?
What else is NOT in this PR
TMA-based communication -- Only SM/thread-based NVLink transfers are implemented. TMA would allow non-blocking transfers initiated by a single thread.
Scale-out (InfiniBand) -- Only intra-node NVLink is covered. Cross-node would require nvSHMEM, NIXL, or NCCL-GIN.
Codegen integration -- These are handwritten kernels. The long-term goal is for nvFuser to generate such patterns automatically.
The devicePeerPointers() method performs CUDA memory allocation and memcpy without comprehensive error handling. Consider adding try-catch blocks or NVF_CHECK for all CUDA operations to ensure proper cleanup on failure paths.
The waitOne() and waitAll() functions use atomic operations with a hardcoded kMaxPoll limit. While this prevents infinite loops, the trap instruction may be too aggressive for production use. Consider adding more graceful error handling or logging for timeout scenarios.
__device__ inline void waitOne(
    int32_t* local,
    int64_t row,
    int64_t m,
    int64_t writer,
    int32_t epoch) {
  auto* p = reinterpret_cast<unsigned int*>(local + (writer * m + row) * kVecW);
  int64_t s = 0;
  while (atomicAdd(p, 0U) < (unsigned)epoch)
    if (++s > kMaxPoll)
      asm volatile("trap;");
}

__device__ inline void waitAll(
    int32_t* local,
    int64_t row,
    int64_t m,
    int64_t ws,
    int32_t epoch) {
  for (int64_t r = 0; r < ws; ++r) {
    auto* p = reinterpret_cast<unsigned int*>(local + (r * m + row) * kVecW);
    int64_t s = 0;
    while (atomicAdd(p, 0U) < (unsigned)epoch)
      if (++s > kMaxPoll)
        asm volatile("trap;");
  }
}
The multimemGatherKernel contains architecture-specific inline assembly for SM90+. While appropriate for experimental code, ensure proper feature detection and fallback paths for non-Hopper architectures to prevent runtime failures.