[Multidevice] Handwritten distributed matmul kernels -- reference implementations for fused comm/compute kernels #6002

Draft
samnordmann wants to merge 23 commits into main from custom_matmul
Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Feb 23, 2026

Linked with Issue

Motivation

nvFuser currently achieves multi-GPU overlap by scheduling separate communication and compute kernels through Host IR, using stream parallelism. This works, but the overlap granularity is coarse, and with this approach communications necessarily act as kernel-fusion boundaries.

This PR explores a different approach: GPU-initiated communication inside compute kernels. Instead of the host orchestrating separate comm and compute phases, a single CUDA kernel reads/writes remote GPU memory directly via symmetric memory pointers, interleaving data movement and computation at the thread level.

This is an experimental reference PR -- not intended for merge as-is, but as a self-contained, readable codebase for the team to study, benchmark, reproduce, and iterate on. The fused scalar kernels demonstrate the comm patterns and synchronization model. The two-kernel CUTLASS variants establish a performance ceiling. Closing the gap -- achieving CUTLASS-level compute inside a truly fused single kernel -- is the central open problem where we need the compute team's expertise.

For simplicity, we focus on the "Allgather + Matmul" problem on a single H100 node over NVLink.

How to run

mpirun -np 8 test_multidevice --gtest_filter=*FusedRemoteMatmulTest*

Requires Hopper (SM90) for the CUTLASS and multimem variants. Build with a flag such as TORCH_CUDA_ARCH_LIST="9.0a".

What this PR contains

A self-contained benchmark comparing 7 distributed matmul implementations for C[M,N] = A[M,K] x B[K,N] where A is row-sharded across ranks on axis M, B is replicated. All code lives in 3 test files:

  • test_multidevice_fused_remote_matmul.h: Shared types, enum, context struct, perf summary
  • test_multidevice_fused_remote_matmul_kernel.cu: CUDA kernels, CUTLASS wrapper, launchers
  • test_multidevice_fused_remote_matmul.cpp: Test harness, resource setup, timing, baselines
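To make the problem setup concrete, here is a small host-side C++ reference for the benchmark's semantics: A[M,K] row-sharded across ranks, B[K,N] replicated, allgather rebuilding full A, then a scalar matmul. This is an illustration only (function name `allgatherMatmulRef` is mine, not from the PR; the real tests use at::matmul and device kernels).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Reference semantics of the benchmark problem (illustration only).
// A[M,K] is row-sharded: rank r owns rows [r*M/ws, (r+1)*M/ws).
// B[K,N] is replicated. Allgather rebuilds full A, then C = A x B.
std::vector<float> allgatherMatmulRef(
    const std::vector<std::vector<float>>& a_shards, // one [M/ws, K] shard per rank
    const std::vector<float>& b, // [K, N], row-major
    int64_t m,
    int64_t n,
    int64_t k) {
  const int64_t ws = static_cast<int64_t>(a_shards.size());
  const int64_t rows_per_rank = m / ws;
  // "Allgather": concatenate the shards along the M axis.
  std::vector<float> a(m * k);
  for (int64_t r = 0; r < ws; ++r)
    for (int64_t i = 0; i < rows_per_rank * k; ++i)
      a[r * rows_per_rank * k + i] = a_shards[r][i];
  // Scalar matmul: C[M,N] = A[M,K] x B[K,N].
  std::vector<float> c(m * n, 0.f);
  for (int64_t row = 0; row < m; ++row)
    for (int64_t col = 0; col < n; ++col)
      for (int64_t kk = 0; kk < k; ++kk)
        c[row * n + col] += a[row * k + kk] * b[kk * n + col];
  return c;
}
```

All seven implementations below compute exactly this result; they differ only in who moves A's rows, when, and with what hardware path.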

Small infrastructure change: SymmetricTensor::devicePeerPointers() added to symmetric_tensor.{h,cpp} -- lazily allocates a device-side pointer table for convenient kernel access to peer buffers.

Implementations

Baselines (separate allgather + matmul, no fusion):

  • baselineNcclAllgatherMatmul -- NCCL allgather to rebuild full A, then at::matmul. The standard-library reference.
  • baselineCudaAllgatherMatmul -- Same pattern, but the allgather uses nvFuser's native backend with multicast NVLS.

Truly fused kernels (comm + compute in a single kernel launch):

  • naiveRemoteRead -- Simplest possible fusion. Each thread computes one C[row,col] by reading A elements directly from the owner rank's shard via remote pointers. No staging, no explicit gather. Every A element traverses NVLink on every access -- no reuse.

  • threadloadGatherScalarCompute -- Two-stage fused kernel. Stage 1: cooperative thread loads gather one full A row from the owner's remote shard into a local staging buffer. Stage 2: scalar matmul from the staged row. Inter-rank synchronization via device-side ready/done semaphores (owner signals readiness; non-owners poll; readers ack completion). See threadloadGatherKernel (tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu, line 292).

  • multimemGatherScalarCompute -- Same two-stage structure, but Stage 1 uses Hopper multimem.st.global.v4.f32 instructions to write A rows to an NVLS multicast buffer, delivering data to all peers in hardware. Requires SM90+ and multicast-capable symmetric memory. See multimemGatherKernel (tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu, line 35).
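The addressing pattern shared by these fused kernels can be sketched on the host. The following C++ model (my illustration, not code from the PR) shows how naiveRemoteRead resolves each row of A through a peer-pointer table; in the actual kernel that table comes from SymmetricTensor::devicePeerPointers() and every read traverses NVLink.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side model of the naiveRemoteRead addressing scheme (illustration
// only). Each output element C[row,col] pulls A[row,:] element by element
// from whichever rank owns that row, via a table of per-rank shard pointers.
void naiveRemoteReadSim(
    const std::vector<const float*>& peer_a, // peer_a[r] -> rank r's [M/ws, K] shard
    const float* b, // [K, N], replicated
    float* c, // [M, N], output
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t ws) {
  const int64_t rows_per_rank = m / ws;
  for (int64_t row = 0; row < m; ++row) {
    const int64_t owner = row / rows_per_rank; // rank holding this row of A
    const int64_t local_row = row % rows_per_rank; // row index inside the shard
    for (int64_t col = 0; col < n; ++col) {
      float acc = 0.f;
      for (int64_t kk = 0; kk < k; ++kk)
        // On device this indexing is a remote NVLink read on every access.
        acc += peer_a[owner][local_row * k + kk] * b[kk * n + col];
      c[row * n + col] = acc;
    }
  }
}
```

The gather variants differ from this model only in inserting a staging stage: the owner's row is copied once into a local (threadload) or multicast (multimem) buffer, and the inner loop then reads the staged copy instead of re-crossing NVLink per element.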

Two-kernel path (separate comm kernel, then CUTLASS GEMM -- NOT truly fused):

  • threadloadGatherThenCutlass -- The threadload gather kernel (with semaphores) materializes full A into a staging buffer, then a separate host-launched CUTLASS 3.x SM90 TMA GEMM consumes that buffer. These are two distinct <<<...>>> launches on the same stream. The gather kernel runs with n=0 to skip its in-kernel compute stage.

  • multimemGatherThenCutlass -- Same as above but using multimem gather instead of threadload gather.

These two-kernel variants establish a performance ceiling: they show what throughput is achievable when the comm pattern is correct and the compute is Hopper-native WGMMA. True single-kernel fusion with equivalent compute quality is the first goal -- see "Where I need the team's input" below.

Performance (8xH100 DGX, M=N=K=1024, half precision)

Implementation                      | ms/iter  | TFLOP/s | Fusion
------------------------------------+----------+---------+--------
baselineNcclAllgatherMatmul         |  0.052   |  41.4   | none
baselineCudaAllgatherMatmul         |  0.093   |  23.2   | none
naiveRemoteRead                     |  0.808   |   2.66  | fused
threadloadGatherScalarCompute       |  0.531   |   4.04  | fused
multimemGatherScalarCompute         |  0.588   |   3.65  | fused
threadloadGatherThenCutlass         |  0.043   |  50.3   | "two-kernel"
multimemGatherThenCutlass           |  0.043   |  50.1   | "two-kernel"

Key observations:

  • The two-kernel variants outperform the NCCL baseline by ~20%, achieving 50 TFLOP/s. This validates that our infrastructure works and can beat standard collectives.
  • The truly fused scalar kernels are 12-15x slower -- the compute is the bottleneck, not the communication. The comm patterns and synchronization model are sound; what's missing is Hopper-native compute inside the fused kernel.
  • The gap between fused-scalar (4 TFLOP/s) and two-kernel-CUTLASS (50 TFLOP/s) is the central problem. Closing it requires embedding WGMMA and TMA pipelines inside the comm kernel (possibly through cutlass device API).

Synchronization model

Fused kernels require device-side inter-rank synchronization since there is no host between the comm and compute stages. This PR implements epoch-based remote semaphores:

  • Semaphore buffers are allocated in symmetric memory (visible to all ranks)
  • The owner rank publishes a monotonically increasing epoch to all peers' semaphore copies via __threadfence_system() + remote writes
  • Reader ranks poll their local semaphore copy with atomicAdd(..., 0) until the expected epoch appears
  • After computation, readers publish a "done" epoch back to the owner, who waits before the next row can be reused
  • Polling is owner-scoped: readers only wait on the specific owner rank, not all ranks.

Where I need the team's input

The central challenge: true single-kernel fusion with Hopper-native compute.

The fused scalar kernels prove that the comm and sync model works. The two-kernel CUTLASS path proves the perf ceiling is high. But achieving both in a single kernel is challenging for me because:

  1. CUTLASS 3.x mainloops are designed to own the entire kernel -- they manage shared memory layout, warpgroup roles (TMA producer vs MMA consumer), and async pipeline barriers. They cannot be called from within another kernel.

  2. TMA descriptors are created on the host via cuTensorMapEncodeTiled. They cannot be created from device code.

  3. WGMMA requires careful warpgroup scheduling -- which warps do MMA, which do data movement, and how shared memory is partitioned between operand staging and communication buffers.

The right approach is likely a custom kernel using CUTE primitives (MMA_Atom, TiledMMA, TMA_LOAD) at the building-block level, weaving P2P comm into the producer/consumer pipeline. This is where I need your expertise:

  • How would you partition warps between P2P communication and WGMMA compute in a single kernel?
  • Can TMA descriptors address remote symmetric memory? If so, the gather stage could use cp.async.bulk instead of thread loads, freeing SMs entirely.
  • What shared memory layout works for staging both the comm buffer (incoming A tiles from remote ranks) and the MMA operand buffers?
  • How should we pipeline -- gather chunk K of A from remote while computing on chunk K-1? What tile sizes and pipeline depths make sense?
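The last question above amounts to a classic double-buffered software pipeline. The sketch below (an assumption about the intended design, not code from this PR) only records the issue order a fused kernel would follow; on device, "gather" would be remote/TMA loads into shared memory, "compute" a WGMMA tile loop, with pipeline barriers between the two.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Issue-order model of a depth-2 gather/compute pipeline over K-chunks
// (illustration only): gather of chunk i overlaps compute on chunk i-1,
// alternating between two staging buffers.
std::vector<std::string> pipelineSchedule(int64_t num_chunks) {
  std::vector<std::string> trace;
  // Prologue: fill buffer 0 with the first chunk.
  trace.push_back("gather(chunk=0, buf=0)");
  for (int64_t i = 1; i < num_chunks; ++i) {
    // Steady state: gather chunk i into the free buffer
    // while computing on chunk i-1 in the other buffer.
    trace.push_back(
        "gather(chunk=" + std::to_string(i) + ", buf=" +
        std::to_string(i % 2) + ")");
    trace.push_back(
        "compute(chunk=" + std::to_string(i - 1) + ", buf=" +
        std::to_string((i - 1) % 2) + ")");
  }
  // Epilogue: drain the final chunk.
  trace.push_back(
      "compute(chunk=" + std::to_string(num_chunks - 1) + ", buf=" +
      std::to_string((num_chunks - 1) % 2) + ")");
  return trace;
}
```

Deeper pipelines generalize this by cycling through more buffers, trading shared-memory capacity for latency hiding; the right depth and tile sizes are exactly the open tuning questions listed above.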

What else is NOT in this PR

  • TMA-based communication -- Only SM/thread-based NVLink transfers are implemented. TMA would allow non-blocking transfers initiated by a single thread.
  • Scale-out (InfiniBand) -- Only intra-node NVLink is covered. Cross-node would require nvSHMEM, NIXL, or NCCL-GIN.
  • Codegen integration -- These are handwritten kernels. The long-term goal is for nvFuser to generate such patterns automatically.

@github-actions

Description

  • Adds comprehensive distributed matmul benchmark with 7 different implementations comparing communication and computation strategies

  • Implements GPU-initiated communication inside compute kernels using remote memory access and synchronization primitives

  • Integrates CUTLASS SM90 TMA matmul for high-performance baseline comparisons

  • Adds device peer pointer support in SymmetricTensor for efficient remote memory access

Changes walkthrough

Relevant files

Enhancement

tests/cpp/test_multidevice_fused_remote_matmul.h -- Header with distributed matmul types and declarations (+125/-0)

  • Defines DistributedMatmulImpl enum with 7 implementation variants
  • Defines BenchmarkConfig and DistributedMatmulContext structures
  • Documents performance benchmarks and implementation strategies
  • Declares kernel launcher functions and utility functions

tests/cpp/test_multidevice_fused_remote_matmul.cpp -- Test harness and benchmark implementation (+470/-0)

  • Implements test harness with timing utilities and resource management
  • Provides benchmarkLoopMs and batchedKernelTimeMs for performance measurement
  • Implements runImplementation dispatcher for all 7 variants
  • Includes correctness validation and performance reporting

tests/cpp/test_multidevice_fused_remote_matmul_kernel.cu -- CUDA kernels for distributed matmul implementations (+604/-0)

  • Implements CUTLASS SM90 TMA matmul wrapper for the high-performance baseline
  • Implements naiveRemoteReadKernel for direct remote memory access
  • Implements threadloadGatherKernel with synchronized P2P gather and scalar compute
  • Implements multimemGatherKernel using Hopper multimem stores and scalar compute
  • Provides semaphore-based synchronization primitives for inter-rank coordination

csrc/multidevice/symmetric_tensor.cpp -- Device peer pointer support in SymmetricTensor (+23/-0)

  • Adds devicePeerPointers() method to return a device-side table of peer pointers
  • Adds proper cleanup in the destructor for device_peer_ptrs_
  • Implements lazy initialization of the device peer pointer table

csrc/multidevice/symmetric_tensor.h -- SymmetricTensor header updates for device peer pointers (+3/-0)

  • Adds devicePeerPointers() method declaration
  • Adds device_peer_ptrs_ member variable for caching the device pointer table

Configuration changes

CMakeLists.txt -- Build system updates for CUTLASS integration and new tests (+10/-4)

  • Adds CUTLASS include directories to the codegen_internal target
  • Adds new test files to BUILD_TEST compilation
  • Updates compiler flags to use generator expressions for the CXX language

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Memory Management

The devicePeerPointers() method performs CUDA memory allocation and memcpy without comprehensive error handling. Consider adding try-catch blocks or NVF_CHECK for all CUDA operations to ensure proper cleanup on failure paths.

    void** SymmetricTensor::devicePeerPointers() const {
      NVF_CHECK(are_remote_tensors_setup_ == true, "Remote tensors not setup");
      if (device_peer_ptrs_ == nullptr) {
        std::vector<void*> host_peer_ptrs(world_size_);
        for (int64_t rank = 0; rank < world_size_; ++rank) {
          host_peer_ptrs[rank] = reinterpret_cast<void*>(remote_ptrs_[rank]);
        }
        NVFUSER_CUDA_RT_SAFE_CALL(
            cudaMalloc(&device_peer_ptrs_, world_size_ * sizeof(void*)));
        NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
            device_peer_ptrs_,
            host_peer_ptrs.data(),
            world_size_ * sizeof(void*),
            cudaMemcpyHostToDevice));
      }
      return device_peer_ptrs_;
    }
Synchronization Robustness

The waitOne() and waitAll() functions use atomic operations with a hardcoded kMaxPoll limit. While this prevents infinite loops, the trap instruction may be too aggressive for production use. Consider adding more graceful error handling or logging for timeout scenarios.

    __device__ inline void waitOne(
        int32_t* local,
        int64_t row,
        int64_t m,
        int64_t writer,
        int32_t epoch) {
      auto* p = reinterpret_cast<unsigned int*>(local + (writer * m + row) * kVecW);
      int64_t s = 0;
      while (atomicAdd(p, 0U) < (unsigned)epoch)
        if (++s > kMaxPoll)
          asm volatile("trap;");
    }
    
    __device__ inline void waitAll(
        int32_t* local,
        int64_t row,
        int64_t m,
        int64_t ws,
        int32_t epoch) {
      for (int64_t r = 0; r < ws; ++r) {
        auto* p = reinterpret_cast<unsigned int*>(local + (r * m + row) * kVecW);
        int64_t s = 0;
        while (atomicAdd(p, 0U) < (unsigned)epoch)
          if (++s > kMaxPoll)
            asm volatile("trap;");
      }
    }
Architecture-Specific Code

The multimemGatherKernel contains architecture-specific inline assembly for SM90+. While appropriate for experimental code, ensure proper feature detection and fallback paths for non-Hopper architectures to prevent runtime failures.

            asm volatile(
                "multimem.st.global.v4.f32 [%0],"
                " {%1, %2, %3, %4};"
                :
                : "l"((void*)(arow + vi * kVec)),
                  "f"(__int_as_float((int)val.x)),
                  "f"(__int_as_float((int)val.y)),
                  "f"(__int_as_float((int)val.z)),
                  "f"(__int_as_float((int)val.w))
                : "memory");
    #else
            (void)val;
            asm volatile("trap;");
    #endif
          }
          for (int64_t kk = nvec * kVec + threadIdx.x; kk < k; kk += blockDim.x)
            arow[kk] = a[lr * k + kk];
        }
        __syncthreads();
    
        // --- Semaphore barrier ---
    #if __CUDA_ARCH__ >= 900
        const int32_t epoch = epoch_base + 1;
        if (threadIdx.x == 0 && rank == owner)
          publishToAll(sem_r, sem_l, rank, row, m, ws, epoch);
        __syncthreads();
        if (threadIdx.x == 0 && rank != owner)
          waitOne(sem_l, row, m, owner, epoch);
        __syncthreads();
    #else
        (void)sem_r;
        (void)sem_l;
        (void)rank;
        (void)ws;
        (void)epoch_base;
        asm volatile("trap;");
    #endif
