Conversation

@denera (Collaborator) commented Nov 7, 2025

Description

This PR integrates the TE/common cuBLASMp bindings into the TE/JAX CollectiveGemm custom op.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera changed the title from "first pass at cuBlasMp integration into CollectiveGemm" to "[JAX] cuBlasMp integration for CollectiveGemm custom op" on Nov 7, 2025
@denera self-assigned this on Nov 7, 2025
@ptrendx added the 2.10.0 label on Nov 14, 2025
@denera marked this pull request as ready for review on November 14, 2025 at 19:14
@greptile-apps bot (Contributor) commented Nov 14, 2025

Greptile Overview

Greptile Summary

This PR integrates cuBLASMp library bindings into the JAX CollectiveGemm custom operation to enable communication-computation overlap for distributed GEMM operations.

Critical Issues Found:

  • Circular typedef in cgemm_helper.h - CollectiveGemmCtx is undefined, should map to CommOverlapCore or NVTECommGemmCtx
  • Function signature mismatch in cgemm_helper.cpp - passing 4 arguments to nvte_comm_gemm_ctx_create which only accepts 3
  • Multiple syntax errors in gemm.cpp cuBLASMp code paths:
    • Undefined variables out_ and _ctx used in function calls
    • Attempting to modify const int k and const int n variables
    • Wrong enum constant CUBLASMP_MATMUL_ALGO_SPLIT_P2P instead of kNVTECommGemmAlgoSplitP2P

Implementation Approach:

  • Added conditional compilation flag NVTE_WITH_CUBLASMP to build system
  • Refactored get_executor to get_context to accommodate both Userbuffers and cuBLASMp implementations
  • Added cuBLASMp code paths for REDUCE_SCATTER and ALL_GATHER collective operations

The code will not compile in its current state when NVTE_WITH_CUBLASMP is enabled.
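
To make the compile-time switch concrete, here is a minimal, self-contained sketch of the pattern the summary describes; the struct names and fields are illustrative stand-ins, not the PR's actual types:

#include <cstdio>

// Stand-in context types; the real code uses CommOverlapCore and NVTECommGemmCtx.
struct UserbuffersCtx { const char *name = "Userbuffers P2P"; };
struct CuBlasMpCtx { const char *name = "cuBLASMp"; };

// Whichever backend is compiled in becomes the single CollectiveGemmCtx alias.
#ifdef NVTE_WITH_CUBLASMP
using CollectiveGemmCtx = CuBlasMpCtx;
#else
using CollectiveGemmCtx = UserbuffersCtx;
#endif

int main() {
  CollectiveGemmCtx ctx;
  std::printf("backend: %s\n", ctx.name);  // prints the compiled-in backend
  return 0;
}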

Confidence Score: 0/5

  • This PR cannot be merged - it contains multiple critical compilation errors that will prevent the code from building
  • Score of 0 reflects that the cuBLASMp integration code paths contain 8+ syntax/compilation errors including undefined variables, circular typedefs, function signature mismatches, const modification violations, and wrong enum constants. The code will fail to compile when NVTE_WITH_CUBLASMP=1 is set.
  • All 3 C++ files require fixes: cgemm_helper.h (typedef), cgemm_helper.cpp (function call), and gemm.cpp (multiple errors in both REDUCE_SCATTER and ALL_GATHER code paths)

Important Files Changed

File Analysis

  • build_tools/jax.py (5/5): adds the NVTE_WITH_CUBLASMP environment variable as a conditional compilation flag for cuBLASMp support
  • transformer_engine/jax/csrc/extensions/cgemm_helper.h (0/5): circular typedef creates a compilation error; CollectiveGemmCtx is undefined, so both typedefs fail
  • transformer_engine/jax/csrc/extensions/cgemm_helper.cpp (0/5): function call passes 4 arguments, but nvte_comm_gemm_ctx_create accepts only 3 parameters (comm, nranks, rank)
  • transformer_engine/jax/csrc/extensions/gemm.cpp (0/5): multiple compilation errors in the cuBLASMp code paths: undefined variables (out_, _ctx), attempts to modify const variables (k, n), and a wrong enum constant (CUBLASMP_MATMUL_ALGO_SPLIT_P2P)

Sequence Diagram

sequenceDiagram
    participant User as JAX User Code
    participant Init as CollectiveGemmInitFFI
    participant Registry as CollectiveGemmPlanRegistry
    participant Handler as CommunicatorHandler
    participant GemmFFI as GemmFFI
    participant cuBlasMp as cuBLASMp Context
    participant UB as Userbuffers (P2P)

    Note over User,UB: Initialization Phase
    User->>Init: CollectiveGemmInitFFI(buffer_shape, dtype, collective_op)
    Init->>Registry: get_context(buffer_shape, dtype, collective_op)
    
    alt NVTE_WITH_CUBLASMP
        Registry->>Handler: get_comm_for_current_device()
        Registry->>cuBlasMp: nvte_comm_gemm_ctx_create(comm, nranks, rank)
        cuBlasMp-->>Registry: ctx (NVTECommGemmCtx)
    else Without cuBLASMp
        Registry->>UB: new CommOverlapP2PBase(...)
        UB-->>Registry: ctx (CommOverlapCore)
    end
    
    Registry-->>Init: context cached in plan_map
    
    Note over User,UB: Computation Phase
    User->>GemmFFI: GemmFFI(lhs, rhs, collective_op)
    GemmFFI->>Registry: get_context(buffer_shape, dtype, collective_op)
    Registry-->>GemmFFI: ctx (from cache)
    
    alt REDUCE_SCATTER with cuBLASMp
        GemmFFI->>cuBlasMp: nvte_gemm_reduce_scatter(ctx, m, n, k_global, ...)
        cuBlasMp->>cuBlasMp: Split GEMM + ReduceScatter overlap
        cuBlasMp-->>GemmFFI: result in output buffer
    else ALL_GATHER with cuBLASMp
        GemmFFI->>cuBlasMp: nvte_all_gather_gemm(ctx, m, n_global, k, ...)
        cuBlasMp->>cuBlasMp: AllGather + GEMM overlap
        cuBlasMp-->>GemmFFI: result in output buffer
    else Without cuBLASMp (Userbuffers)
        GemmFFI->>UB: ctx->split_overlap_rs/ag(...)
        UB->>UB: P2P-based comm-compute overlap
        UB-->>GemmFFI: result in output buffer
    end
    
    GemmFFI-->>User: computation complete

@greptile-apps bot left a comment

4 files reviewed, 8 comments

Comment on lines +160 to +164:

#ifndef NVTE_WITH_CUBLASMP
typedef CollectiveGemmCtx CommOverlapCore;
#else
typedef CollectiveGemmCtx CommGemmCtx;
#endif

syntax: CollectiveGemmCtx is undefined, creating a circular typedef.

Suggested change:

#ifndef NVTE_WITH_CUBLASMP
typedef CommOverlapCore CollectiveGemmCtx;
#else
typedef NVTECommGemmCtx CollectiveGemmCtx;
#endif
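
For reference, a minimal sketch of the same aliases written as C++11 using-declarations, which read left-to-right (new alias = existing type) and make this kind of reversal harder to write; the transformer_engine namespace qualification is an assumption here:

#ifndef NVTE_WITH_CUBLASMP
using CollectiveGemmCtx = transformer_engine::CommOverlapCore;
#else
using CollectiveGemmCtx = NVTECommGemmCtx;
#endif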

Comment on lines +205 to +207:

ctx = nvte_comm_gemm_ctx_create(comm_handler.get_comm_for_current_device(),
                                comm_handler.num_total_devices, comm_handler.get_global_rank(),
                                te::cuda::current_device());

syntax: function signature mismatch; nvte_comm_gemm_ctx_create takes 3 parameters (comm, nranks, rank), but 4 are passed here.

Suggested change:

ctx = nvte_comm_gemm_ctx_create(comm_handler.get_comm_for_current_device(),
                                comm_handler.num_total_devices, comm_handler.get_global_rank());
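
To see the arity error in isolation, here is a self-contained mock (every type and function below is a stand-in invented for this illustration, not TE's API):

struct MockComm {};
struct MockCtx { int nranks; int rank; };

// A 3-parameter factory, mirroring the arity described in the review.
static MockCtx *mock_ctx_create(MockComm /*comm*/, int nranks, int rank) {
  return new MockCtx{nranks, rank};
}

int main() {
  MockComm comm;
  MockCtx *ctx = mock_ctx_create(comm, /*nranks=*/8, /*rank=*/0);
  // mock_ctx_create(comm, 8, 0, /*device_id=*/0);  // error: too many arguments
  delete ctx;
  return 0;
}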

const int n = (transb) ? lhs_shape[1] : lhs_shape[0];
const int k = (transa) ? rhs_shape[1] : rhs_shape[0];

k *= ctx->nranks;  // convert contracting dimension to global size

syntax: k is declared const int on line 272 and cannot be modified.

Suggested change:

int k_global = k * ctx->nranks;  // convert contracting dimension to global size
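
A minimal, self-contained illustration of the fix, deriving a new value instead of mutating a const (the nranks value is made up for the example):

#include <cassert>

int main() {
  const int k = 128;
  const int nranks = 4;             // stand-in for ctx->nranks
  // k *= nranks;                   // error: k is const
  const int k_global = k * nranks;  // the suggested fix: a new variable
  assert(k_global == 512);
  return 0;
}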


k *= ctx->nranks;  // convert contracting dimension to global size
NVTE_CHECK_CUBLASMP(
    nvte_gemm_reduce_scatter(ctx, m, n, k, rhs_.data(), lhs_.data(), out_.data(), bias_.data(),

syntax: out_ is undefined in the NVTE_WITH_CUBLASMP code path (it is only defined in the #ifndef block above).

Suggested change:

    nvte_gemm_reduce_scatter(ctx, m, n, k_global, rhs_.data(), lhs_.data(), output->untyped_data(), bias_.data(),

const int n = (transb) ? lhs_shape[1] : lhs_shape[0];
const int k = (transa) ? rhs_shape[1] : rhs_shape[0];

n *= ctx->nranks;  // convert all-gathered dimension to global size

syntax: n is declared const int on line 297 and cannot be modified.

Suggested change:

int n_global = n * ctx->nranks;  // convert all-gathered dimension to global size


n *= ctx->nranks;  // convert all-gathered dimension to global size
NVTE_CHECK_CUBLASMP(
    nvte_all_gather_gemm(_ctx, m, n, k, rhs_.data(), lhs_.data(), out_.data(), bias_.data(),

syntax: typo; _ctx should be ctx.

Suggested change:

    nvte_all_gather_gemm(ctx, m, n_global, k, rhs_.data(), lhs_.data(), output->untyped_data(), bias_.data(),

Comment on lines +276 to +278:

nvte_gemm_reduce_scatter(ctx, m, n, k, rhs_.data(), lhs_.data(), out_.data(), bias_.data(),
                         pre_gelu_.data(), rhs_transposed, lhs_transposed, grad,
                         use_split_accumulator, 0, stream, CUBLASMP_MATMUL_ALGO_SPLIT_P2P));

syntax: wrong enum constant; use kNVTECommGemmAlgoSplitP2P from the NVTECommGemmAlgoType enum instead of the undefined CUBLASMP_MATMUL_ALGO_SPLIT_P2P.

Suggested change:

nvte_gemm_reduce_scatter(ctx, m, n, k_global, rhs_.data(), lhs_.data(), output->untyped_data(), bias_.data(),
                         pre_gelu_.data(), rhs_transposed, lhs_transposed, grad,
                         use_split_accumulator, 0, stream, kNVTECommGemmAlgoSplitP2P));
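
The broader pattern behind this fix: callers pass the wrapper library's enum, and the translation to backend constants lives in one place. A minimal, self-contained sketch of that idea (all names hypothetical, not TE's or cuBLASMp's):

#include <cstdio>

// What callers are allowed to pass.
enum NvteAlgo { kNvteAlgoSplitP2P };

// Backend-internal constant; callers never reference it directly.
enum BackendAlgo { BACKEND_ALGO_SPLIT_P2P = 7 };

// The single translation point from wrapper enum to backend enum.
static BackendAlgo to_backend(NvteAlgo algo) {
  switch (algo) {
    case kNvteAlgoSplitP2P:
      return BACKEND_ALGO_SPLIT_P2P;
  }
  return BACKEND_ALGO_SPLIT_P2P;  // unreachable with a single enumerator
}

int main() {
  std::printf("backend algo id: %d\n", to_backend(kNvteAlgoSplitP2P));
  return 0;
}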

Comment on lines +302 to +304:

nvte_all_gather_gemm(_ctx, m, n, k, rhs_.data(), lhs_.data(), out_.data(), bias_.data(),
                     pre_gelu_.data(), rhs_transposed, lhs_transposed, grad,
                     use_split_accumulator, 0, stream, CUBLASMP_MATMUL_ALGO_SPLIT_P2P));

syntax: wrong enum constant; use kNVTECommGemmAlgoSplitP2P from the NVTECommGemmAlgoType enum instead of the undefined CUBLASMP_MATMUL_ALGO_SPLIT_P2P.

Suggested change:

nvte_all_gather_gemm(ctx, m, n_global, k, rhs_.data(), lhs_.data(), output->untyped_data(), bias_.data(),
                     pre_gelu_.data(), rhs_transposed, lhs_transposed, grad,
                     use_split_accumulator, 0, stream, kNVTECommGemmAlgoSplitP2P));

@phu0ngng (Collaborator) commented:

Hi, could you add some unit tests?

CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector<size_t> buffer_shape,
                                                          DType dtype,
                                                          JAXX_Collective_Op collective_op) {
CollectiveGemmCtx *CollectiveGemmPlanRegistry::get_context(std::vector<size_t> buffer_shape,

@jberchtold-nvidia commented Nov 14, 2025:

Let's restructure this so that, instead of a compile-time flag, we compile both backend implementations, similar to TE's NormPlanRegistry, which can use either cuDNN or TE kernels. This way a user can easily switch between the two without needing to rebuild TE/JAX.
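
A minimal sketch of the runtime selection this comment suggests, with both backends compiled in and chosen at startup; every name here, including the environment variable, is hypothetical rather than the PR's or NormPlanRegistry's actual API:

#include <cstdio>
#include <cstdlib>
#include <cstring>

enum class CgemmBackend { kUserbuffers, kCuBlasMp };

// Pick a backend at runtime (e.g. from an environment variable) so users can
// switch implementations without rebuilding TE/JAX.
static CgemmBackend select_backend() {
  const char *env = std::getenv("NVTE_JAX_CGEMM_BACKEND");  // hypothetical knob
  if (env != nullptr && std::strcmp(env, "cublasmp") == 0) {
    return CgemmBackend::kCuBlasMp;
  }
  return CgemmBackend::kUserbuffers;  // default backend
}

int main() {
  const char *name =
      select_backend() == CgemmBackend::kCuBlasMp ? "cuBLASMp" : "Userbuffers";
  std::printf("selected backend: %s\n", name);
  return 0;
}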

