Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Feb 23, 2025
1 parent 9ec3649 commit 8f8a81e
Showing 5 changed files with 71 additions and 75 deletions.
@@ -367,16 +367,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
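
For orientation, this hunk is only a line-wrapping change: the parameter list of nvte_fused_attn_fwd_kvpacked is unchanged. A minimal caller-side sketch of the reformatted entry point follows; the tensor handles, tensor pack, and enum values are assumed to be created elsewhere, and the trailing scalar values (is_training = true, dropout = 0, window sizes -1/-1) are illustrative assumptions rather than values taken from this commit.

#include <cuda_runtime.h>
#include <transformer_engine/fused_attn.h>

// Hypothetical helper: forwards pre-built handles in the argument order shown above.
void fused_attn_fwd_kvpacked_sketch(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
    NVTETensorPack *aux_ctx_tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace,
    cudaStream_t stream) {
  nvte_fused_attn_fwd_kvpacked(
      Q, KV, Bias, S, O, aux_ctx_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded,
      cu_seqlens_kv_padded, page_table_k, page_table_v, rng_state, max_seqlen_q, max_seqlen_kv,
      /*is_training=*/true, attn_scale, /*dropout=*/0.0f, qkv_layout, bias_type, attn_mask_type,
      /*window_size_left=*/-1, /*window_size_right=*/-1,  // -1/-1: assumed "no sliding window"
      workspace, stream);
}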

/*! \brief Compute the backward of the dot product attention with packed KV input.
*
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -70,10 +70,10 @@ std::vector<py::object> fused_attn_bwd(
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
int h_q, int d_q, int b, int max_ctx_len, int max_seq_len);
void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o,
int d_o, int b, int max_seq_len, bool is_output_right_aligned);
void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q,
int d_q, int b, int max_ctx_len, int max_seq_len);
void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens,
int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned);
void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int h_kv, int d_k,
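
The two reshape declarations above move data between a packed (thd) tensor and a padded (bshd) buffer; per the layout comments in kv_cache.cuh below, cu_new_lens is a length-(b + 1) tensor of cumulative per-sequence token counts, read as int32 on the device. A hedged usage sketch for reshape_q — the shapes, dtype, and lengths are illustrative assumptions only:

#include <torch/torch.h>

// Declaration from extensions.h (shown above), repeated so the sketch is self-contained.
void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q,
               int d_q, int b, int max_ctx_len, int max_seq_len);

void reshape_q_usage_sketch() {
  const int b = 2, h_q = 8, d_q = 64, max_ctx_len = 16, max_seq_len = 128;
  // Cumulative new-token counts: sequence 0 contributes 3 tokens, sequence 1 contributes 5.
  auto cu_new_lens = torch::tensor({0, 3, 8}, torch::dtype(torch::kInt32).device(torch::kCUDA));
  // Packed thd-layout input: [total_new_tokens, h_q, d_q].
  auto new_q = torch::randn({8, h_q, d_q}, torch::dtype(torch::kHalf).device(torch::kCUDA));
  // Padded bshd-layout destination: [b, max_seq_len, h_q, d_q].
  auto q_buffer = torch::zeros({b, max_seq_len, h_q, d_q},
                               torch::dtype(torch::kHalf).device(torch::kCUDA));
  reshape_q(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
}
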
73 changes: 37 additions & 36 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -4,8 +4,8 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "thd_utils.cuh"
#include "kv_cache.cuh"
#include "thd_utils.cuh"

constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
@@ -1036,36 +1036,32 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
template <typename scalar_t>
void reshape_q_launcher(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) {
transformer_engine::fused_attn::reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_q.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(q_buffer.data_ptr<scalar_t>()), cu_new_lens.data_ptr<int>(),
h_q, d_q, b, max_ctx_len, max_seq_len);
transformer_engine::fused_attn::
reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_q.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(q_buffer.data_ptr<scalar_t>()), cu_new_lens.data_ptr<int>(),
h_q, d_q, b, max_ctx_len, max_seq_len);
}

void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) {
void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q,
int d_q, int b, int max_ctx_len, int max_seq_len) {
NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(),
"new_q and q_buffer must be of the same data type.");
if (q_buffer.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
max_seq_len);
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
} else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
max_seq_len);
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
} else if (q_buffer.scalar_type() == at::ScalarType::Float) {
using dtype = float;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
max_seq_len);
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
} else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
max_seq_len);
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
} else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
max_seq_len);
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, max_seq_len);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
Expand All @@ -1076,16 +1072,18 @@ void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new
**************************************************************************************************/

template <typename scalar_t>
void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens,
int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) {
transformer_engine::fused_attn::reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(output.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(output_buffer.data_ptr<scalar_t>()), cu_new_lens.data_ptr<int>(),
h_o, d_o, b, max_seq_len, is_output_right_aligned);
void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer,
torch::Tensor cu_new_lens, int h_o, int d_o, int b, int max_seq_len,
bool is_output_right_aligned) {
transformer_engine::fused_attn::
reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(output.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(output_buffer.data_ptr<scalar_t>()),
cu_new_lens.data_ptr<int>(), h_o, d_o, b, max_seq_len, is_output_right_aligned);
}

void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o,
int d_o, int b, int max_seq_len, bool is_output_right_aligned) {
void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens,
int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) {
NVTE_CHECK(output.scalar_type() == output_buffer.scalar_type(),
"output and output_buffer must be of the same data type.");
if (output.scalar_type() == at::ScalarType::Half) {
@@ -1136,18 +1134,21 @@ void copy_to_kv_cache_launcher(torch::Tensor new_k, torch::Tensor new_v, torch::
if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr &&
v_cache.data_ptr() != nullptr) {
if (is_non_paged) {
transformer_engine::fused_attn::reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()), page_table.data_ptr<int>(),
cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
transformer_engine::fused_attn::
reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()),
page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
}
transformer_engine::fused_attn::copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()), page_table.data_ptr<int>(),
cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v, b,
max_ctx_len, max_seq_len, max_pages_per_seq);
transformer_engine::fused_attn::
copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()), page_table.data_ptr<int>(),
cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v,
b, max_ctx_len, max_seq_len, max_pages_per_seq);
}
}

46 changes: 21 additions & 25 deletions transformer_engine/pytorch/csrc/kv_cache.cuh
@@ -9,13 +9,12 @@
namespace transformer_engine {
namespace fused_attn {
template <typename scalar_t>
__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens,
int h_q, int d_q, int b,
int max_ctx_len, int max_seq_len) {
__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens, int h_q,
int d_q, int b, int max_ctx_len, int max_seq_len) {
// new_q: thd; q_buffer: bshd;
// cu_new_lens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]) * h_q * d_q;
int num_elts = (cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]) * h_q * d_q;
int new_token_offset = cu_new_lens[batch_idx] * h_q * d_q;
int cache_offset = batch_idx * max_seq_len * h_q * d_q;
scalar_t *new_q_token = new_q + new_token_offset;
@@ -27,12 +26,13 @@ __global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_ne
}
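
All of the kernels in this header use the same two-level indexing scheme seen in reshape_q_kernel: blocks stride over sequences (blockIdx.x / gridDim.x) while threads stride over that sequence's elements (threadIdx.x / blockDim.x), which is why the launchers in attention.cu can use a fixed <<<16, 256>>> configuration. A stripped-down sketch of the pattern, illustrative only and not code from this commit:

template <typename scalar_t>
__global__ void packed_to_padded_copy_sketch(const scalar_t *src, scalar_t *dst,
                                             const int *cu_lens, int elts_per_token, int b,
                                             int max_seq_len) {
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    // cu_lens is cumulative, so adjacent entries give this sequence's token count.
    int num_elts = (cu_lens[batch_idx + 1] - cu_lens[batch_idx]) * elts_per_token;
    const scalar_t *src_seq = src + cu_lens[batch_idx] * elts_per_token;  // packed (thd-like)
    scalar_t *dst_seq = dst + batch_idx * max_seq_len * elts_per_token;   // padded (bshd-like)
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      dst_seq[i] = src_seq[i];
    }
  }
}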

template <typename scalar_t>
__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *cu_new_lens, int h_o,
int d_o, int b, int max_seq_len, bool is_output_right_aligned) {
__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *cu_new_lens,
int h_o, int d_o, int b, int max_seq_len,
bool is_output_right_aligned) {
// output: bshd; output_buffer: thd;
// cu_new_lens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
int num_elts = new_len * h_o * d_o;
int output_offset = batch_idx * max_seq_len * h_o * d_o;
if (is_output_right_aligned) {
@@ -49,8 +49,8 @@ __global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int

template <typename scalar_t>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, int d_v,
int b, int max_seq_len) {
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
int d_v, int b, int max_seq_len) {
// k_cache, v_cache: bshd
// batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1]
int actual_b = b;
Expand All @@ -60,10 +60,9 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
}
}
for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
for (int token_idx = blockIdx.x; token_idx < cached_len - new_len;
token_idx += gridDim.x) {
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) {
int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k;
@@ -99,12 +98,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = page_table + batch_idx * max_pages_per_seq;
int new_token_offset = batch_idx * max_ctx_len;
int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = page_list[(cached_len - new_len + i) / page_size];
int token_idx =
page_idx * page_size + (cached_len - new_len + i) % page_size;
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (new_token_offset + i) * h_kv * d_k + j);
@@ -118,12 +116,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = page_list[(cached_len - new_len + i) / page_size];
int token_idx =
page_idx * page_size + (cached_len - new_len + i) % page_size;
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j);
}
@@ -135,12 +132,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
} else if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx];
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = page_list[(cached_len - new_len + i) / page_size];
int token_idx =
page_idx * page_size + (cached_len - new_len + i) % page_size;
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j);
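
Every paged branch of copy_to_kv_cache_kernel computes the destination slot the same way: the i-th new token of a sequence sits at logical position p = cached_len - new_len + i (cached_len already includes the new tokens), its physical page comes from that sequence's row of the page table, and its row in the flat cache is page * page_size + p % page_size. A small host-side sketch of that address calculation, illustrative only:

// Destination token row in k_cache/v_cache for the i-th new token of one sequence.
// page_list points at this sequence's row of the page table (max_pages_per_seq entries).
inline int paged_cache_token_row(const int *page_list, int cached_len, int new_len, int i,
                                 int page_size) {
  int p = cached_len - new_len + i;             // logical position within the sequence
  int page_idx = page_list[p / page_size];      // physical page holding this position
  return page_idx * page_size + p % page_size;  // token row inside the flat cache
}
// Example: cached_len = 37, new_len = 1, i = 0, page_size = 16
//   -> p = 36, page slot 36 / 16 = 2, in-page offset 36 % 16 = 4,
//      so the token lands at row page_list[2] * 16 + 4 of the cache.
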
5 changes: 3 additions & 2 deletions transformer_engine/pytorch/inference.py
@@ -53,6 +53,7 @@ class InferenceParams:
cache_manager: KVCacheManager, default = None
Custom cache manager, with KVCacheManager as the base class.
"""

def __init__(
self,
max_batch_size: int,
@@ -254,8 +255,8 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str):
new_k_cache = new_k_cache[:actual_batch_size].contiguous()
new_v_cache = new_v_cache[:actual_batch_size].contiguous()
if qkv_format == "sbhd":
new_k_cache = new_k_cache.transpose(0,1)
new_v_cache = new_v_cache.transpose(0,1)
new_k_cache = new_k_cache.transpose(0, 1)
new_v_cache = new_v_cache.transpose(0, 1)
if qkv_format == "thd":
assert False, "UnfusedDotProductAttention does not support qkv_format=thd."
