From 8f8a81e1c3cdfd2cd4a1b1ad4cb47ca4cbfff835 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 23 Feb 2025 01:08:53 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../include/transformer_engine/fused_attn.h | 14 ++--
 transformer_engine/pytorch/csrc/extensions.h | 8 +-
 .../pytorch/csrc/extensions/attention.cu | 73 ++++++++++---------
 transformer_engine/pytorch/csrc/kv_cache.cuh | 46 ++++++------
 transformer_engine/pytorch/inference.py | 5 +-
 5 files changed, 71 insertions(+), 75 deletions(-)

diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h
index 62ad226962..3c7b3f5817 100644
--- a/transformer_engine/common/include/transformer_engine/fused_attn.h
+++ b/transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -367,16 +367,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  * \param[in] workspace Workspace tensor.
  * \param[in] stream CUDA stream used for this operation.
  */
-void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
-                                  NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
-                                  const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+void nvte_fused_attn_fwd_kvpacked(
+    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
+    NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
     const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
     const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
-                                  size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
-                                  float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                                  NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                  int64_t window_size_left, int64_t window_size_right,
-                                  NVTETensor workspace, cudaStream_t stream);
+    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
 
 /*! \brief Compute the backward of the dot product attention with packed KV input.
  *
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 6689c89d73..515b6fe602 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -70,10 +70,10 @@ std::vector<at::Tensor> fused_attn_bwd(
 at::Tensor fa_prepare_fwd(at::Tensor qkvi);
 at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
 
-void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
-               int h_q, int d_q, int b, int max_ctx_len, int max_seq_len);
-void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o,
-               int d_o, int b, int max_seq_len, bool is_output_right_aligned);
+void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q,
+               int d_q, int b, int max_ctx_len, int max_seq_len);
+void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens,
+               int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned);
 void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
                       torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
                       torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int h_kv, int d_k,
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu
index 1caf6b15e4..4337065e1f 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cu
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -4,8 +4,8 @@
  * See LICENSE for license information.
  ************************************************************************/
 #include "extensions.h"
-#include "thd_utils.cuh"
 #include "kv_cache.cuh"
+#include "thd_utils.cuh"
 
 constexpr int block_size = 512;
 constexpr int ctas_per_sm = 4;
@@ -1036,36 +1036,32 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
 template <typename scalar_t>
 void reshape_q_launcher(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
                         int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) {
-  transformer_engine::fused_attn::reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
-      reinterpret_cast<scalar_t *>(new_q.data_ptr()),
-      reinterpret_cast<scalar_t *>(q_buffer.data_ptr()), cu_new_lens.data_ptr<int>(),
-      h_q, d_q, b, max_ctx_len, max_seq_len);
+  transformer_engine::fused_attn::
+      reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+          reinterpret_cast<scalar_t *>(new_q.data_ptr()),
+          reinterpret_cast<scalar_t *>(q_buffer.data_ptr()), cu_new_lens.data_ptr<int>(),
+          h_q, d_q, b, max_ctx_len, max_seq_len);
 }
 
-void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
-               int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) {
+void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q,
+               int d_q, int b, int max_ctx_len, int max_seq_len) {
   NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(),
              "new_q and q_buffer must be of the same data type.");
   if (q_buffer.scalar_type() == at::ScalarType::Half) {
     using dtype = at::Half;
-    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
-                              max_seq_len);
+    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
   } else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) {
     using dtype = at::BFloat16;
-    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
-                              max_seq_len);
+    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
   } else if (q_buffer.scalar_type() == at::ScalarType::Float) {
     using dtype = float;
-    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
-                              max_seq_len);
+    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
   } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) {
     using dtype = at::Float8_e4m3fn;
-    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
-                              max_seq_len);
+    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
   } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) {
     using dtype = at::Float8_e5m2;
-    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
-                              max_seq_len);
+    reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
   } else {
     NVTE_ERROR("Unsupported dtype for KV cache.\n");
   }
@@ -1076,16 +1072,18 @@ void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new
  **************************************************************************************************/
 
 template <typename scalar_t>
-void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens,
-                        int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) {
-  transformer_engine::fused_attn::reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
-      reinterpret_cast<scalar_t *>(output.data_ptr()),
-      reinterpret_cast<scalar_t *>(output_buffer.data_ptr()), cu_new_lens.data_ptr<int>(),
-      h_o, d_o, b, max_seq_len, is_output_right_aligned);
+void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer,
+                        torch::Tensor cu_new_lens, int h_o, int d_o, int b, int max_seq_len,
+                        bool is_output_right_aligned) {
+  transformer_engine::fused_attn::
+      reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+          reinterpret_cast<scalar_t *>(output.data_ptr()),
+          reinterpret_cast<scalar_t *>(output_buffer.data_ptr()),
+          cu_new_lens.data_ptr<int>(), h_o, d_o, b, max_seq_len, is_output_right_aligned);
 }
 
-void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o,
-               int d_o, int b, int max_seq_len, bool is_output_right_aligned) {
+void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens,
+               int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) {
   NVTE_CHECK(output.scalar_type() == output_buffer.scalar_type(),
              "output and output_buffer must be of the same data type.");
   if (output.scalar_type() == at::ScalarType::Half) {
@@ -1136,18 +1134,21 @@ void copy_to_kv_cache_launcher(torch::Tensor new_k, torch::Tensor new_v, torch::
   if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr &&
       v_cache.data_ptr() != nullptr) {
     if (is_non_paged) {
-      transformer_engine::fused_attn::reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
-          reinterpret_cast<scalar_t *>(k_cache.data_ptr()),
-          reinterpret_cast<scalar_t *>(v_cache.data_ptr()), page_table.data_ptr<int>(),
-          cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
+      transformer_engine::fused_attn::
+          reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+              reinterpret_cast<scalar_t *>(k_cache.data_ptr()),
+              reinterpret_cast<scalar_t *>(v_cache.data_ptr()),
+              page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
+              cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
     }
-    transformer_engine::fused_attn::copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
-        reinterpret_cast<scalar_t *>(new_k.data_ptr()),
-        reinterpret_cast<scalar_t *>(new_v.data_ptr()),
-        reinterpret_cast<scalar_t *>(k_cache.data_ptr()),
-        reinterpret_cast<scalar_t *>(v_cache.data_ptr()), page_table.data_ptr<int>(),
-        cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v, b,
-        max_ctx_len, max_seq_len, max_pages_per_seq);
+    transformer_engine::fused_attn::
+        copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+            reinterpret_cast<scalar_t *>(new_k.data_ptr()),
+            reinterpret_cast<scalar_t *>(new_v.data_ptr()),
+            reinterpret_cast<scalar_t *>(k_cache.data_ptr()),
+            reinterpret_cast<scalar_t *>(v_cache.data_ptr()), page_table.data_ptr<int>(),
+            cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v,
+            b, max_ctx_len, max_seq_len, max_pages_per_seq);
   }
 }
 
diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh b/transformer_engine/pytorch/csrc/kv_cache.cuh
index bf343226f1..7e6f6979da 100644
--- a/transformer_engine/pytorch/csrc/kv_cache.cuh
+++ b/transformer_engine/pytorch/csrc/kv_cache.cuh
@@ -9,13 +9,12 @@
 namespace transformer_engine {
 namespace fused_attn {
 template <typename scalar_t>
-__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens,
-                                 int h_q, int d_q, int b,
-                                 int max_ctx_len, int max_seq_len) {
+__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens, int h_q,
+                                 int d_q, int b, int max_ctx_len, int max_seq_len) {
   // new_q: thd; q_buffer: bshd;
   // cu_new_lens: [b + 1]
   for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
-    int num_elts = (cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]) * h_q * d_q;
+    int num_elts = (cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]) * h_q * d_q;
     int new_token_offset = cu_new_lens[batch_idx] * h_q * d_q;
     int cache_offset = batch_idx * max_seq_len * h_q * d_q;
     scalar_t *new_q_token = new_q + new_token_offset;
@@ -27,12 +26,13 @@ __global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_ne
 }
 
 template <typename scalar_t>
-__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *cu_new_lens, int h_o,
-                                 int d_o, int b, int max_seq_len, bool is_output_right_aligned) {
+__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *cu_new_lens,
+                                 int h_o, int d_o, int b, int max_seq_len,
+                                 bool is_output_right_aligned) {
   // output: bshd; output_buffer: thd;
   // cu_new_lens: [b + 1]
   for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
-    int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
+    int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
     int num_elts = new_len * h_o * d_o;
     int output_offset = batch_idx * max_seq_len * h_o * d_o;
     if (is_output_right_aligned) {
@@ -49,8 +49,8 @@ __global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int
 
 template <typename scalar_t>
 __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
-                                        int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, int d_v,
-                                        int b, int max_seq_len) {
+                                        int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
+                                        int d_v, int b, int max_seq_len) {
   // k_cache, v_cache: bshd
   // batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1]
   int actual_b = b;
@@ -60,10 +60,9 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
     }
   }
   for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
-    int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx];
-    int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
-    for (int token_idx = blockIdx.x; token_idx < cached_len - new_len;
-         token_idx += gridDim.x) {
+    int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
+    int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
+    for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) {
       int num_elts_k = h_kv * d_k;
      int num_elts_v = h_kv * d_v;
       int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k;
@@ -99,12 +98,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
     for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
       int *page_list = page_table + batch_idx * max_pages_per_seq;
       int new_token_offset = batch_idx * max_ctx_len;
-      int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx];
-      int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
+      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
+      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
       for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
         int page_idx = page_list[(cached_len - new_len + i) / page_size];
-        int token_idx =
-            page_idx * page_size + (cached_len - new_len + i) % page_size;
+        int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
         for (int j = 0; j < h_kv * d_k; j++) {
           *(k_cache + token_idx * h_kv * d_k + j) =
               *(new_k + (new_token_offset + i) * h_kv * d_k + j);
@@ -118,12 +116,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
   } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
     for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
       int *page_list = page_table + batch_idx * max_pages_per_seq;
-      int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx];
-      int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
+      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
+      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
       for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
         int page_idx = page_list[(cached_len - new_len + i) / page_size];
-        int token_idx =
-            page_idx * page_size + (cached_len - new_len + i) % page_size;
+        int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
         for (int j = 0; j < h_kv * d_k; j++) {
           *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j);
         }
@@ -135,12 +132,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
   } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
     for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
       int *page_list = page_table + batch_idx * max_pages_per_seq;
-      int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx];
-      int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
+      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
+      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
       for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
         int page_idx = page_list[(cached_len - new_len + i) / page_size];
-        int token_idx =
-            page_idx * page_size + (cached_len - new_len + i) % page_size;
+        int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
         for (int j = 0; j < h_kv * d_k; j++) {
           *(k_cache + token_idx * h_kv * d_k + j) =
               *(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j);
diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py
index d9c63f9158..fbb343ed72 100644
--- a/transformer_engine/pytorch/inference.py
+++ b/transformer_engine/pytorch/inference.py
@@ -53,6 +53,7 @@ class InferenceParams:
     cache_manager: KVCacheManager, default = None
                    Custom cache manager, with KVCacheManager as the base class.
     """
+
     def __init__(
         self,
         max_batch_size: int,
@@ -254,8 +255,8 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str):
         new_k_cache = new_k_cache[:actual_batch_size].contiguous()
         new_v_cache = new_v_cache[:actual_batch_size].contiguous()
         if qkv_format == "sbhd":
-            new_k_cache = new_k_cache.transpose(0,1)
-            new_v_cache = new_v_cache.transpose(0,1)
+            new_k_cache = new_k_cache.transpose(0, 1)
+            new_v_cache = new_v_cache.transpose(0, 1)
 
         if qkv_format == "thd":
             assert False, "UnfusedDotProductAttention does not support qkv_format=thd."
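
Note for readers (illustrative only, not part of the patch): the reformatted
copy_to_kv_cache_kernel keeps the same paged-cache addressing, where a token's
logical position in its sequence is mapped through the per-sequence page table
to a physical slot in the cache. Below is a minimal host-side C++ sketch of
that index arithmetic; the toy values (page_size, max_pages_per_seq,
cached_len, new_len, the page-table row) are assumed purely for illustration.

    #include <cstdio>

    int main() {
      // Toy sizes, assumed only for this sketch.
      const int page_size = 4;          // tokens per KV-cache page
      const int max_pages_per_seq = 8;  // length of one page-table row
      int page_list[max_pages_per_seq] = {5, 2, 7, 0, 3, 1, 6, 4};  // toy page-table row

      const int cached_len = 10;  // tokens cached for this sequence after the update
      const int new_len = 3;      // tokens appended in this step

      for (int i = 0; i < new_len; ++i) {
        // Same mapping as the kernel: logical position -> page -> physical cache slot.
        int pos = cached_len - new_len + i;
        int page_idx = page_list[pos / page_size];
        int token_idx = page_idx * page_size + pos % page_size;
        printf("new token %d -> logical pos %d -> page %d -> cache slot %d\n", i, pos,
               page_idx, token_idx);
      }
      return 0;
    }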