From 361723909769702a3f8bb10d12b952eec98823c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Sep 2024 04:57:26 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/comm_gemm_overlap.h | 680 +++++++++--------- 1 file changed, 342 insertions(+), 338 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index df3e0c58a7..a7d9eddea8 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -966,378 +966,382 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int recv_offset = comm_bytes * recv_chunk_id; int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; if (_use_fused_sendrecv) { - userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, - _ub_comm, peer_rank, peer_rank, (cudaStream_t)_stream_send); + userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, _ub_comm, + peer_rank, peer_rank, (cudaStream_t)_stream_send); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); - else { - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); - } - - int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; - const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; - const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; - - // Ring exchange of 2X inputs chunks - for (int i = 0; i < num_steps; i++) { - send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; - recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; - send_offset = comm_bytes * send_chunk_id; - recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - torch::Tensor input_b_chunk = - torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options()); - torch::Tensor output_chunk = torch::from_blob( - output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options()); - if (do_gelu) { - pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk * 2, m}, pre_gelu_out.options()); + else { + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, + peer_rank, (cudaStream_t)_stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + peer_rank, (cudaStream_t)_stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); } - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - if (i < num_steps - 1) { - // P2P communication - if (_use_fused_sendrecv) { - userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes * 2, - _ub_comm, next_rank, prev_rank, (cudaStream_t)_stream_send); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else { - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, - prev_rank, (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; + const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; + const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; + + // Ring exchange of 2X inputs chunks + for (int i = 0; i < num_steps; i++) { + send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; + recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; + send_offset = comm_bytes * send_chunk_id; + recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + torch::Tensor input_b_chunk = + torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options()); + torch::Tensor output_chunk = torch::from_blob( + output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options()); + if (do_gelu) { + pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), + {n_chunk * 2, m}, pre_gelu_out.options()); + } + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, + transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, + grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + if (i < num_steps - 1) { + // P2P communication + if (_use_fused_sendrecv) { + userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes * 2, + _ub_comm, next_rank, prev_rank, (cudaStream_t)_stream_send); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent( + (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else { + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, + next_rank, (cudaStream_t)_stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, + prev_rank, (cudaStream_t)_stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent( + (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); } - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - } - } - } else { - for (int i = 0; i < _tp_size; i++) { - // Set the userbuffer id. Buffer under send is the input for the current - // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to - // have the AG output in all ranks to be contiguous after the ring - // exchanges - int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; - int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - torch::Tensor output_chunk = torch::from_blob( - output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options()); - if (do_gelu) { - pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk, m}, pre_gelu_out.options()); } - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, _ubufs[send_chunk_id], B_scale_inverse, B_type, - transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - if (i < _tp_size - 1) { - // P2P communication - if (_use_fused_sendrecv) { - userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, - _ub_comm, _next_rank, _prev_rank, (cudaStream_t)_stream_send); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else { - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else { + for (int i = 0; i < _tp_size; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; + int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + torch::Tensor output_chunk = torch::from_blob( + output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options()); + if (do_gelu) { + pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), + {n_chunk, m}, pre_gelu_out.options()); + } + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(A, A_scale_inverse, A_type, transa, _ubufs[send_chunk_id], B_scale_inverse, + B_type, transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, + pre_gelu_out, grad, workspace_chunk, workspace_size_chunk, accumulate, + use_split_accumulator, _math_sms); + + if (i < _tp_size - 1) { + // P2P communication + if (_use_fused_sendrecv) { + userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, _ub_comm, + _next_rank, _prev_rank, (cudaStream_t)_stream_send); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent( + (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else { + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, + _next_rank, (cudaStream_t)_stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + _prev_rank, (cudaStream_t)_stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent( + (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); } - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); } } - } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - if (!_use_fused_sendrecv) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - } - at::cuda::setCurrentCUDAStream(stream_main); - _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + 
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + if (!_use_fused_sendrecv) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); + } + at::cuda::setCurrentCUDAStream(stream_main); + _ub_comm->sms = ori_sms; - return D; - } // split_overlap_ag + return D; + } // split_overlap_ag - /* + /* ** Split ReduceScatter + GEMM using P2P communication */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - - // Get communication and GEMM input chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - - // Get input and workspace data pointers - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int *counter_ptr = reinterpret_cast(counter.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + void atomic_gemm_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get communication and GEMM input chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Get input and workspace data pointers + char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); + int *counter_ptr = reinterpret_cast(counter.data_ptr()); + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + // Catch up the main stream + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + // Atomic GEMM + // Process GEMM chunks in the order that AG+GEMM places the output chunks. 
+ torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); + te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, _ubuf, + D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_chunk, + workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, 0, + _tp_size, true, counter); + + // P2P communication chunk + for (int i = 1; i < _tp_size; i++) { + int send_chunk_id = i - 1; + int recv_chunk_id = send_chunk_id + _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + send_rank, (cudaStream_t)_stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + recv_rank, (cudaStream_t)_stream_recv); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - // Atomic GEMM - // Process GEMM chunks in the order that AG+GEMM places the output chunks. - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, _ubuf, - D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_chunk, - workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, 0, _tp_size, - true, counter); - - // P2P communication chunk - for (int i = 1; i < _tp_size; i++) { - int send_chunk_id = i - 1; - int recv_chunk_id = send_chunk_id + _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - - consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, - (cudaStream_t)_stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, - (cudaStream_t)_stream_recv); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - - // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main);); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, 
_ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D_type, fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, + _tp_size, _ubufs[0].numel(), + (cudaStream_t)stream_main);); + } else { + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + } + _ub_comm->sms = ori_sms; } - _ub_comm->sms = ori_sms; - } - /* + /* ** Split ReduceScatter + GEMM using P2P communication */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - int k = A.size(1); - int n = B.size(0); - - // Get communication and GEMM input chunk sizes - int n_chunk = n / _tp_size; - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const int input_b_chunk_bytes = n_chunk * k * B.element_size(); - - // Get input and workspace data pointers - char *input_b_ptr = reinterpret_cast(B.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - if (!_use_fused_sendrecv) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - } + void split_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + int k = A.size(1); + int n = B.size(0); + + // Get communication 
and GEMM input chunk sizes + int n_chunk = n / _tp_size; + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int input_b_chunk_bytes = n_chunk * k * B.element_size(); + + // Get input and workspace data pointers + char *input_b_ptr = reinterpret_cast(B.data_ptr()); + char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + // Catch up the main stream + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + if (!_use_fused_sendrecv) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + } + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); + } - // GEMM and send/recv chunks - for (int i = 0; i < _tp_size; i++) { - // GEMM chunk - int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); - torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options()); - // Store the last GEMM chunk output to the recieve buffer. - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, - _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); + // GEMM and send/recv chunks + for (int i = 0; i < _tp_size; i++) { + // GEMM chunk + int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; + char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); + torch::Tensor input_b_chunk = + torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options()); + // Store the last GEMM chunk output to the recieve buffer. 
+ torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, + _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); - if (i > 0) { - // P2P communication chunk - int send_offset = comm_bytes * (i - 1); - int recv_offset = comm_bytes * (i - 1 + _tp_size); - int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - NVTE_CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - if (_use_fused_sendrecv) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); - userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, - _ub_comm, send_rank, recv_rank, (cudaStream_t)_stream_send); - } else { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0)); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - send_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - recv_rank, (cudaStream_t)_stream_recv); + if (i > 0) { + // P2P communication chunk + int send_offset = comm_bytes * (i - 1); + int recv_offset = comm_bytes * (i - 1 + _tp_size); + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + NVTE_CHECK_CUDA(cudaEventRecord( + _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); + if (_use_fused_sendrecv) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); + userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, _ub_comm, + send_rank, recv_rank, (cudaStream_t)_stream_send); + } else { + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0)); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + send_rank, (cudaStream_t)_stream_send); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + recv_rank, (cudaStream_t)_stream_recv); + } } } - } - at::cuda::setCurrentCUDAStream(stream_main); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - if (!_use_fused_sendrecv) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - } + at::cuda::setCurrentCUDAStream(stream_main); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, 
_stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + if (!_use_fused_sendrecv) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); + } - // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main);); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D_type, fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, + _tp_size, _ubufs[0].numel(), + (cudaStream_t)stream_main);); + } else { + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + } + _ub_comm->sms = ori_sms; } - _ub_comm->sms = ori_sms; - } - /* + /* ** Copy input to _ubufs[0] */ - void copy_input_to_ubuf(torch::Tensor input, bool chunk) { - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - if (chunk) { - // Copy input to the target ubuf chunk by rank offset - if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); - } else { - if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); + void copy_input_to_ubuf(torch::Tensor input, bool chunk) { + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); + if (chunk) { + // Copy input to the target ubuf chunk by rank offset + if (input.numel() != _ubufs[0].numel() || + input.element_size() != _ubufs[0].element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + } else { + if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(), 
- input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); } - } - torch::Tensor get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); - if (_comm_type == COMM_TYPE::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); - int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - } + torch::Tensor get_ubuf_output(int comm_type) { + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); + COMM_TYPE _comm_type = static_cast(comm_type); + if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) + NVTE_ERROR("Invalid comm_type"); + if (_comm_type == COMM_TYPE::RS) + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); + int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); + } - void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { - _ubuf_scale_inv = scale_inv; - _ubuf_scale_inv_initialized = true; - } + void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { + _ubuf_scale_inv = scale_inv; + _ubuf_scale_inv_initialized = true; + } - bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } - bool is_atomic_gemm() { return _atomic_gemm; } - bool is_p2p_overlap() { return true; } -}; // UbufP2PCommOverlap + bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } + bool is_atomic_gemm() { return _atomic_gemm; } + bool is_p2p_overlap() { return true; } + }; // UbufP2PCommOverlap } // namespace ubuf
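
The hunks above only re-wrap the ring-exchange overlap code, but the choreography they touch is easy to lose in the diff noise: each GEMM chunk is launched on a round-robin compute stream while the next chunk's P2P transfer is issued on a separate communication stream, with cudaEventRecord/cudaStreamWaitEvent enforcing producer/consumer ordering and a final fan-in back onto the main stream. The standalone sketch below reproduces just that stream/event pattern on a single GPU; fake_gemm, the buffer sizes, and the cudaMemcpyAsync stand-in for userbuffers_send/userbuffers_recv are illustrative assumptions, not Transformer Engine APIs.

// Minimal single-GPU sketch of the chunked compute/communication overlap pattern,
// assuming a stand-in kernel in place of te_gemm and cudaMemcpyAsync in place of the
// userbuffers send/recv calls. Compile with: nvcc -o overlap_sketch overlap_sketch.cu
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CHECK_CUDA(call)                                                 \
  do {                                                                   \
    cudaError_t err_ = (call);                                           \
    if (err_ != cudaSuccess) {                                           \
      fprintf(stderr, "CUDA error %s at %s:%d\n",                        \
              cudaGetErrorString(err_), __FILE__, __LINE__);             \
      return 1;                                                          \
    }                                                                    \
  } while (0)

// Stand-in for the per-chunk GEMM: scales one chunk in place.
__global__ void fake_gemm(float *chunk, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) chunk[i] *= 2.0f;
}

int main() {
  const int num_chunks = 4;           // plays the role of _tp_size
  const int chunk_elems = 1 << 20;    // per-chunk size (assumption)
  const int num_compute_streams = 3;  // plays the role of _stream_compute.size()

  float *buf;   // plays the role of the registered _ubuf (num_chunks chunks back to back)
  float *recv;  // stand-in destination for the "communication" copies
  CHECK_CUDA(cudaMalloc(&buf, (size_t)num_chunks * chunk_elems * sizeof(float)));
  CHECK_CUDA(cudaMalloc(&recv, (size_t)num_chunks * chunk_elems * sizeof(float)));

  // Non-blocking streams so the explicit event dependencies below are what orders the work.
  cudaStream_t comm_stream;
  std::vector<cudaStream_t> compute(num_compute_streams);
  CHECK_CUDA(cudaStreamCreateWithFlags(&comm_stream, cudaStreamNonBlocking));
  for (auto &s : compute) CHECK_CUDA(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));

  cudaEvent_t start_compute, stop_comm, stop_compute;
  CHECK_CUDA(cudaEventCreateWithFlags(&start_compute, cudaEventDisableTiming));
  CHECK_CUDA(cudaEventCreateWithFlags(&stop_comm, cudaEventDisableTiming));
  CHECK_CUDA(cudaEventCreateWithFlags(&stop_compute, cudaEventDisableTiming));

  // Fork: all side streams wait for prior work on the main (default) stream.
  CHECK_CUDA(cudaEventRecord(start_compute, 0));
  CHECK_CUDA(cudaStreamWaitEvent(comm_stream, start_compute, 0));
  for (auto &s : compute) CHECK_CUDA(cudaStreamWaitEvent(s, start_compute, 0));

  for (int i = 0; i < num_chunks; ++i) {
    float *chunk = buf + (size_t)i * chunk_elems;
    cudaStream_t s = compute[i % num_compute_streams];

    // "GEMM" for chunk i on a round-robin compute stream.
    fake_gemm<<<(chunk_elems + 255) / 256, 256, 0, s>>>(chunk, chunk_elems);

    if (i < num_chunks - 1) {
      // Stand-in "send/recv" of the next chunk on the comm stream, overlapped with compute.
      float *next = buf + (size_t)(i + 1) * chunk_elems;
      CHECK_CUDA(cudaMemcpyAsync(recv + (size_t)(i + 1) * chunk_elems, next,
                                 chunk_elems * sizeof(float), cudaMemcpyDeviceToDevice,
                                 comm_stream));
      // The compute stream that will consume the transferred chunk waits on the comm stream,
      // mirroring the _stop_recv record/wait pairs in the hunks above.
      CHECK_CUDA(cudaEventRecord(stop_comm, comm_stream));
      CHECK_CUDA(cudaStreamWaitEvent(compute[(i + 1) % num_compute_streams], stop_comm, 0));
    }
  }

  // Join: the main stream waits for every compute stream and the comm stream.
  for (auto &s : compute) {
    CHECK_CUDA(cudaEventRecord(stop_compute, s));
    CHECK_CUDA(cudaStreamWaitEvent(0, stop_compute, 0));
  }
  CHECK_CUDA(cudaEventRecord(stop_comm, comm_stream));
  CHECK_CUDA(cudaStreamWaitEvent(0, stop_comm, 0));
  CHECK_CUDA(cudaStreamSynchronize(0));

  printf("overlap sketch finished\n");
  return 0;
}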