[JAX] cuBlasMp integration for CollectiveGemm custom op #2361
Conversation
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Greptile Summary

This PR integrates cuBLASMp library bindings into the JAX CollectiveGemm custom operation to enable communication-computation overlap for distributed GEMM operations.

Critical issues found: the code will not compile in its current state when NVTE_WITH_CUBLASMP is defined; the review comments below detail the specific errors.

Confidence Score: 0/5
Sequence Diagram

```mermaid
sequenceDiagram
participant User as JAX User Code
participant Init as CollectiveGemmInitFFI
participant Registry as CollectiveGemmPlanRegistry
participant Handler as CommunicatorHandler
participant GemmFFI as GemmFFI
participant cuBlasMp as cuBLASMp Context
participant UB as Userbuffers (P2P)
Note over User,UB: Initialization Phase
User->>Init: CollectiveGemmInitFFI(buffer_shape, dtype, collective_op)
Init->>Registry: get_context(buffer_shape, dtype, collective_op)
alt NVTE_WITH_CUBLASMP
Registry->>Handler: get_comm_for_current_device()
Registry->>cuBlasMp: nvte_comm_gemm_ctx_create(comm, nranks, rank)
cuBlasMp-->>Registry: ctx (NVTECommGemmCtx)
else Without cuBLASMp
Registry->>UB: new CommOverlapP2PBase(...)
UB-->>Registry: ctx (CommOverlapCore)
end
Registry-->>Init: context cached in plan_map
Note over User,UB: Computation Phase
User->>GemmFFI: GemmFFI(lhs, rhs, collective_op)
GemmFFI->>Registry: get_context(buffer_shape, dtype, collective_op)
Registry-->>GemmFFI: ctx (from cache)
alt REDUCE_SCATTER with cuBLASMp
GemmFFI->>cuBlasMp: nvte_gemm_reduce_scatter(ctx, m, n, k_global, ...)
cuBlasMp->>cuBlasMp: Split GEMM + ReduceScatter overlap
cuBlasMp-->>GemmFFI: result in output buffer
else ALL_GATHER with cuBLASMp
GemmFFI->>cuBlasMp: nvte_all_gather_gemm(ctx, m, n_global, k, ...)
cuBlasMp->>cuBlasMp: AllGather + GEMM overlap
cuBlasMp-->>GemmFFI: result in output buffer
else Without cuBLASMp (Userbuffers)
GemmFFI->>UB: ctx->split_overlap_rs/ag(...)
UB->>UB: P2P-based comm-compute overlap
UB-->>GemmFFI: result in output buffer
end
GemmFFI-->>User: computation complete
```
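For readers who prefer code to diagrams, here is a condensed, self-contained sketch of the compile-time dispatch the diagram describes. The stand-in types and function bodies below are illustrative only, not the PR's actual code:

```cpp
#include <cstdio>

// Stand-ins for the two backend context types named in the diagram.
struct NVTECommGemmCtx {};   // cuBLASMp context (NVTE_WITH_CUBLASMP builds)
struct CommOverlapCore {     // Userbuffers P2P overlap object (other builds)
  void split_overlap_rs() { std::puts("Userbuffers: split GEMM + P2P reduce-scatter"); }
};

#ifdef NVTE_WITH_CUBLASMP
using CollectiveGemmCtx = NVTECommGemmCtx;
void reduce_scatter_gemm(CollectiveGemmCtx * /*ctx*/) {
  // The real code would call nvte_gemm_reduce_scatter(ctx, ...), which
  // overlaps the GEMM and the reduce-scatter internally.
  std::puts("cuBLASMp: nvte_gemm_reduce_scatter");
}
#else
using CollectiveGemmCtx = CommOverlapCore;
void reduce_scatter_gemm(CollectiveGemmCtx *ctx) { ctx->split_overlap_rs(); }
#endif

int main() {
  CollectiveGemmCtx ctx;  // in the PR this comes from CollectiveGemmPlanRegistry's cache
  reduce_scatter_gemm(&ctx);
  return 0;
}
```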
4 files reviewed, 8 comments
```cpp
#ifndef NVTE_WITH_CUBLASMP
typedef CollectiveGemmCtx CommOverlapCore;
#else
typedef CollectiveGemmCtx CommGemmCtx;
#endif
```
syntax: CollectiveGemmCtx is undefined, creating circular typedef
Suggested change:

```cpp
#ifndef NVTE_WITH_CUBLASMP
typedef CommOverlapCore CollectiveGemmCtx;
#else
typedef NVTECommGemmCtx CollectiveGemmCtx;
#endif
```
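For reference, a minimal stand-alone illustration of the typedef ordering at issue: the existing type must precede the new alias, which is why the original lines reference an undefined name.

```cpp
// Self-contained illustration with stand-in types, not TE code: typedef reads
// "typedef ExistingType NewAlias;", so the existing type comes first.
struct CommOverlapCore {};  // stand-in for the real TE class

typedef CommOverlapCore CollectiveGemmCtx;  // OK: aliases an existing type

// The original diff reversed the operands, using the not-yet-declared
// CollectiveGemmCtx as the source type, which fails to compile:
// typedef CollectiveGemmCtx CommOverlapCore;

int main() {
  CollectiveGemmCtx ctx;  // resolves to CommOverlapCore
  (void)ctx;
  return 0;
}
```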
```cpp
ctx = nvte_comm_gemm_ctx_create(comm_handler.get_comm_for_current_device(),
                                comm_handler.num_total_devices, comm_handler.get_global_rank(),
                                te::cuda::current_device());
```
syntax: function signature mismatch - nvte_comm_gemm_ctx_create takes 3 parameters (comm, nranks, rank) but 4 are passed here
Suggested change:

```cpp
ctx = nvte_comm_gemm_ctx_create(comm_handler.get_comm_for_current_device(),
                                comm_handler.num_total_devices, comm_handler.get_global_rank());
```
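For context, a sketch of the three-parameter form the comment describes; the declaration below is inferred from the review, and the authoritative signature lives in TE's comm_gemm header.

```cpp
// Sketch only: stand-in declarations so this is self-contained; the real comm
// handle type and return type may differ from what is shown here.
struct ncclComm;                      // NCCL communicator (opaque)
typedef struct ncclComm *ncclComm_t;  // matches NCCL's public handle typedef
struct NVTECommGemmCtx;               // opaque cuBLASMp context

// Assumed three-parameter form per the review comment: (comm, nranks, rank).
NVTECommGemmCtx *nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank);
```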
```cpp
const int n = (transb) ? lhs_shape[1] : lhs_shape[0];
const int k = (transa) ? rhs_shape[1] : rhs_shape[0];
// ...
k *= ctx->nranks;  // convert contracting dimension to global size
```
syntax: k is declared const int on line 272, cannot be modified
Suggested change:

```cpp
int k_global = k * ctx->nranks;  // convert contracting dimension to global size
```
```cpp
k *= ctx->nranks;  // convert contracting dimension to global size
NVTE_CHECK_CUBLASMP(
    nvte_gemm_reduce_scatter(ctx, m, n, k, rhs_.data(), lhs_.data(), out_.data(), bias_.data(),
```
syntax: out_ variable is undefined in the NVTE_WITH_CUBLASMP code path (only defined in the #ifndef block above)
Suggested change:

```cpp
nvte_gemm_reduce_scatter(ctx, m, n, k_global, rhs_.data(), lhs_.data(), output->untyped_data(), bias_.data(),
```
```cpp
const int n = (transb) ? lhs_shape[1] : lhs_shape[0];
const int k = (transa) ? rhs_shape[1] : rhs_shape[0];
// ...
n *= ctx->nranks;  // convert all-gathered dimension to global size
```
syntax: n is declared const int on line 297, cannot be modified
Suggested change:

```cpp
int n_global = n * ctx->nranks;  // convert all-gathered dimension to global size
```
```cpp
n *= ctx->nranks;  // convert all-gathered dimension to global size
NVTE_CHECK_CUBLASMP(
    nvte_all_gather_gemm(_ctx, m, n, k, rhs_.data(), lhs_.data(), out_.data(), bias_.data(),
```
syntax: typo: _ctx should be ctx
Suggested change:

```cpp
nvte_all_gather_gemm(ctx, m, n_global, k, rhs_.data(), lhs_.data(), output->untyped_data(), bias_.data(),
```
```cpp
nvte_gemm_reduce_scatter(ctx, m, n, k, rhs_.data(), lhs_.data(), out_.data(), bias_.data(),
                         pre_gelu_.data(), rhs_transposed, lhs_transposed, grad,
                         use_split_accumulator, 0, stream, CUBLASMP_MATMUL_ALGO_SPLIT_P2P));
```
syntax: wrong enum constant - should use kNVTECommGemmAlgoSplitP2P from NVTECommGemmAlgoType enum instead of undefined CUBLASMP_MATMUL_ALGO_SPLIT_P2P
Suggested change:

```cpp
nvte_gemm_reduce_scatter(ctx, m, n, k_global, rhs_.data(), lhs_.data(), output->untyped_data(), bias_.data(),
                         pre_gelu_.data(), rhs_transposed, lhs_transposed, grad,
                         use_split_accumulator, 0, stream, kNVTECommGemmAlgoSplitP2P));
```
```cpp
nvte_all_gather_gemm(_ctx, m, n, k, rhs_.data(), lhs_.data(), out_.data(), bias_.data(),
                     pre_gelu_.data(), rhs_transposed, lhs_transposed, grad,
                     use_split_accumulator, 0, stream, CUBLASMP_MATMUL_ALGO_SPLIT_P2P));
```
syntax: wrong enum constant - should use kNVTECommGemmAlgoSplitP2P from NVTECommGemmAlgoType enum instead of undefined CUBLASMP_MATMUL_ALGO_SPLIT_P2P
Suggested change:

```cpp
nvte_all_gather_gemm(ctx, m, n_global, k, rhs_.data(), lhs_.data(), output->untyped_data(), bias_.data(),
                     pre_gelu_.data(), rhs_transposed, lhs_transposed, grad,
                     use_split_accumulator, 0, stream, kNVTECommGemmAlgoSplitP2P));
```
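For reference, the algorithm selector used in these suggestions would come from an enum shaped roughly as below; only kNVTECommGemmAlgoSplitP2P is named in the review, so the rest is an illustrative guess.

```cpp
// Sketch of the assumed NVTECommGemmAlgoType enum. Only kNVTECommGemmAlgoSplitP2P
// is confirmed by the review comments; the default enumerator is assumed.
typedef enum NVTECommGemmAlgoType {
  kNVTECommGemmAlgoDefault = 0,  // assumed default algorithm selector
  kNVTECommGemmAlgoSplitP2P,     // split GEMM with P2P-based comm overlap
} NVTECommGemmAlgoType;
```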
|
Hi, could you add some unit tests?
```diff
-CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector<size_t> buffer_shape,
-                                                          DType dtype,
-                                                          JAXX_Collective_Op collective_op) {
+CollectiveGemmCtx *CollectiveGemmPlanRegistry::get_context(std::vector<size_t> buffer_shape,
```
Let's restructure this so that, instead of a compile-time flag, we compile both backend implementations, similar to TE's NormPlanRegistry support for either cuDNN or TE kernels. This way a user can easily switch between the two without needing to rebuild TE/JAX.
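A minimal sketch of what that runtime selection could look like, assuming a hypothetical environment variable as the switch (all names below are illustrative, not existing TE API):

```cpp
#include <cstdlib>
#include <cstring>

enum class CollectiveGemmBackend { kUserbuffers, kCuBlasMp };

// Pick the backend per process instead of per build. The env var name is
// hypothetical; the real knob would be chosen by the maintainers.
inline CollectiveGemmBackend active_backend() {
  const char *flag = std::getenv("NVTE_JAX_COLLECTIVE_GEMM_BACKEND");
  if (flag != nullptr && std::strcmp(flag, "cublasmp") == 0) {
    return CollectiveGemmBackend::kCuBlasMp;
  }
  return CollectiveGemmBackend::kUserbuffers;  // default: Userbuffers P2P path
}
```

get_context() could then branch on active_backend() when creating or fetching a cached context, with both backend implementations compiled into the library.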
Description
This PR integrates TE/common cuBlasMp bindings into the TE/JAX CollectiveGemm custom op.
Type of change
Checklist: