Problem
nvFuser currently handles multi-GPU execution by lowering to Host IR, which orchestrates separate communication kernels (e.g. NCCL) and compute kernels on different streams. Overlap is achieved by pipelining at the stream level, which has fundamental granularity limitations. Moreover, host-initiated communications force kernel fusion boundaries.
Proposed direction: GPU-initiated fused comm+compute
With CUDA symmetric memory, GPUs can read/write each other's memory directly, without host intervention. This enables fused kernels where a single CUDA kernel interleaves communication (remote memory access) and computation at the warp or thread-block level.
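To make the idea concrete, here is a minimal hedged sketch of such a fused kernel: ordinary threads dereference pointers into peer ranks' symmetric allocations and apply scalar compute in the same launch. The names `peer_ptrs` and `my_rank` are illustrative, not nvFuser or PR APIs, and the required cross-rank synchronization is omitted.

```cuda
// Sketch only: assumes `peer_ptrs[r]` points into rank r's symmetric-memory
// allocation and that producers have already signaled their shards ready.
__global__ void fused_gather_scale(
    float* const* peer_ptrs,  // one pointer per rank, mapped via symmetric memory
    float* out,
    int world_size,
    int shard_elems) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  for (int r = 0; r < world_size; ++r) {
    const float* remote = peer_ptrs[r];  // direct loads travel over NVLink
    for (int i = tid; i < shard_elems; i += stride) {
      // Communication (remote load) and compute (scale) fused in one kernel.
      out[r * shard_elems + i] = 2.0f * remote[i];
    }
  }
}
```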
PR #6002 provides a first set of handwritten reference implementations demonstrating this for a distributed matmul (allgather + GEMM). It includes:
- 3 truly fused kernels (comm + scalar compute in one launch) that validate the communication patterns and device-side semaphore synchronization
- 2 two-kernel paths (separate P2P gather kernel, then CUTLASS TMA GEMM) that establish the performance ceiling at 50 TFLOP/s on 8xH100, outperforming NCCL by 20%
The gap between the fused-scalar kernels (4 TFLOP/s) and the two-kernel CUTLASS path (50 TFLOP/s) is the core problem. The comm infrastructure works; what's missing is Hopper-native compute (WGMMA) inside the fused kernel.
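The device-side semaphore synchronization mentioned above can be sketched as a release/acquire flag protocol on symmetric memory; the function names and flag layout below are assumptions for illustration, not the PR's actual interface. System scope (`.sys`) is what makes the store visible across GPUs.

```cuda
// Hedged sketch: producer publishes an epoch after its shard is written;
// consumers spin with an acquire load. `flag` lives in symmetric memory.
__device__ void signal_ready(unsigned int* flag, unsigned int epoch) {
  // Release store at system scope so peer GPUs observe prior writes first.
  asm volatile("st.release.sys.global.u32 [%0], %1;" ::"l"(flag), "r"(epoch));
}

__device__ void wait_ready(const unsigned int* flag, unsigned int epoch) {
  unsigned int seen;
  do {
    // Acquire load at system scope; orders subsequent remote reads.
    asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(seen) : "l"(flag));
  } while (seen < epoch);
}
```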
Key open problems (input needed)
1. True single-kernel fusion with WGMMA compute
The CUTLASS 3.x mainloop owns the entire kernel (shared memory, warpgroup roles, async pipeline). We cannot call it from inside a comm kernel. The path forward is a custom kernel using CUTE building blocks (MMA_Atom, TiledMMA, TMA_LOAD) that weaves P2P comm into the warpgroup pipeline. Concrete questions:
- How to partition warps between P2P data movement and WGMMA?
- What shared memory layout supports both comm staging and MMA operand buffers?
- What pipeline depth and tile sizes balance comm latency hiding with compute throughput?
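One possible answer to the first two questions, sketched as a structural skeleton (an assumption, not the PR's code): a dedicated producer warpgroup does P2P staging into a circular shared-memory buffer that doubles as the MMA operand buffer, while the remaining warpgroups run WGMMA. Barrier and helper logic are placeholders for CUTLASS pipeline primitives.

```cuda
// Skeleton only; kStages and tile sizes are illustrative assumptions.
__global__ void fused_ag_gemm_skeleton() {
  constexpr int kStages = 3;                    // smem pipeline depth
  __shared__ float smem_a[kStages][128 * 64];   // comm staging == MMA operand buffer
  const int warpgroup = threadIdx.x / 128;      // 128 threads = one Hopper warpgroup

  if (warpgroup == 0) {
    // Producer: for each K-tile, pull the shard from its owning rank into
    // smem_a[stage] via P2P loads (or TMA, if remote descriptors work),
    // then arrive on that stage's mbarrier.
  } else {
    // Consumers: wait on the stage's mbarrier, issue wgmma.mma_async on the
    // staged operands, commit the group, and release the stage for reuse.
  }
}
```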
2. TMA for P2P communication
Can TMA descriptors address remote symmetric memory pointers? If yes, this would allow non-blocking P2P transfers from a single thread, freeing all other warps for compute. This could be transformative for the fused kernel design.
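As a point of reference, the single-thread bulk-copy pattern looks like the sketch below (sm_90+). Whether the source operand may be a remote symmetric-memory pointer is exactly the open question; the helper name is hypothetical.

```cuda
// Hedged sketch: one elected thread issues a non-blocking bulk copy that
// completes against an mbarrier, leaving all other warps free for compute.
__device__ void bulk_p2p_copy(void* smem_dst, const void* remote_src,
                              unsigned bytes, uint64_t* mbar) {
  if (threadIdx.x == 0) {
    asm volatile(
        "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes"
        " [%0], [%1], %2, [%3];"
        ::"r"((unsigned)__cvta_generic_to_shared(smem_dst)),
          "l"(remote_src),  // open question: may this be a remote/multicast ptr?
          "r"(bytes),
          "r"((unsigned)__cvta_generic_to_shared(mbar)));
  }
}
```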
3. Pipelined execution
Current implementations gather all of A, then compute all of C (or run them as separate kernels). A pipelined approach -- gather tile K while computing on tile K-1 -- would hide comm latency under compute. This maps naturally to the CUTLASS async pipeline model but needs to handle cross-rank data dependencies.
PyTorch has a very interesting implementation of this mechanism written in pure CUTLASS. IIUC, it achieves the same result as our pipelined circular P2P algorithm.
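The pipelined structure above can be sketched as a double-buffered K-loop. `prefetch_remote_tile`, `wait_tile_ready`, `owner_rank`, and `mma_on_tile` are hypothetical helpers standing in for the P2P copy, the cross-rank semaphore wait, rank-to-tile ownership, and the WGMMA mainloop step.

```cuda
// Sketch only: gather tile k+1 while computing on tile k, hiding comm latency.
for (int k = 0; k < num_k_tiles; ++k) {
  const int cur = k % 2;
  const int nxt = (k + 1) % 2;
  if (k + 1 < num_k_tiles) {
    // Start pulling tile k+1 from whichever rank owns it (non-blocking).
    prefetch_remote_tile(smem_a[nxt], owner_rank(k + 1), k + 1);
  }
  wait_tile_ready(cur);                        // device-side semaphore / mbarrier
  mma_on_tile(acc, smem_a[cur], smem_b[cur]);  // compute overlaps the prefetch
}
```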
4. Codegen
It would be great to eventually generate such fused kernels automatically. What new scheduling primitives would be needed to express these patterns? How should nvFuser represent:
- Cross-device data dependencies ("this tile comes from rank X")
- GPU-initiated communication operations within a kernel
- Warp role assignments (comm vs compute warps)
I'm guessing these concepts already exist in the single-GPU codegen infrastructure, but we will definitely need the team's help to understand them and adapt them to multi-GPU.
5. Beyond AG+Matmul
Allgather+matmul is the starting pattern, but reduce-scatter after matmul, MM+all-reduce, and especially MoE routing could all benefit from fusion. Which patterns should we target next?
Note on communication transport options (NVLink domain)
| Transport | Mechanism | Blocking? | SM cost | Status |
|---|---|---|---|---|
| SM/thread loads/stores | Regular thread memory ops via remote pointers | Blocking | High | Implemented in PR |
| TMA (cp.async.bulk) | Single-thread DMA via TMA descriptors | Non-blocking | Very low | Not implemented yet; open question whether it works on multicast ptrs |
| NVLS multicast (multimem.st) | Hardware broadcast via NVSwitch | Blocking | Medium | Implemented in PR |
| Copy engine (cudaMemcpyAsync) | Host-initiated DMA | N/A (host) | Zero | Host-initiated; not relevant for fused kernels |
For scale-out (InfiniBand), integration with nvSHMEM, NIXL, or NCCL-GIN would be needed.