WIP: some cleanup
Signed-off-by: Charlene Yang <[email protected]>
cyanguwa committed Feb 23, 2025
1 parent a391a49 commit 9ec3649
Showing 15 changed files with 570 additions and 615 deletions.
3 changes: 1 addition & 2 deletions qa/L0_pytorch_unittest/test.sh
@@ -13,8 +13,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
@@ -23,3 +21,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py
3 changes: 1 addition & 2 deletions transformer_engine/common/CMakeLists.txt
@@ -98,8 +98,7 @@ target_include_directories(transformer_engine PUBLIC
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDNN::cudnn_all)
CUDA::cudart)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
20 changes: 11 additions & 9 deletions transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -157,17 +157,17 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);

/*! \brief Get Q format for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbh3d.
* \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
*
* \return q format, e.g. sbhd.
*/
NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout);

/*! \brief Get KV format for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbh3d.
* \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
*
* \return kv format, e.g. sbhd.
* \return kv format, e.g. bshd.
*/
NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
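For context, a minimal sketch of how the two format queries might be used for a mixed layout in which Q is stored as sbhd and the K/V cache as bshd; the enumerator names used here (NVTE_SBHD_BSHD_BSHD, NVTE_SBHD, NVTE_BSHD) are assumptions not shown in this diff:

#include <transformer_engine/fused_attn.h>

// Sketch only: query the separate Q and KV formats of a mixed qkv_layout.
// Enumerator names below are assumed, not taken from this commit.
void check_formats() {
  NVTE_QKV_Layout layout = NVTE_SBHD_BSHD_BSHD;
  NVTE_QKV_Format q_format = nvte_get_q_format(layout);    // expect NVTE_SBHD
  NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);  // expect NVTE_BSHD
  (void)q_format;
  (void)kv_format;
}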

@@ -367,14 +367,16 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
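A rough sketch of a call site for the reformatted declaration, written as a thin wrapper over caller-created tensors so that only the argument order is illustrated; the null handles for the optional bias and padded-offset tensors, and the NVTE_BSHD_BSHD_BSHD / NVTE_NO_BIAS / NVTE_PADDING_MASK enumerators, are assumptions, not taken from this commit:

#include <cuda_runtime.h>
#include <transformer_engine/fused_attn.h>

// Sketch only: forwards pre-built NVTETensor handles in the argument order of
// the declaration above, configured for a non-training (inference) call.
void fused_attn_fwd_kvpacked_inference(
    const NVTETensor Q, const NVTETensor KV, NVTETensor S, NVTETensor O,
    NVTETensorPack* aux_ctx_tensors, const NVTETensor cu_seqlens_q,
    const NVTETensor cu_seqlens_kv, const NVTETensor page_table_k,
    const NVTETensor page_table_v, const NVTETensor rng_state,
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale,
    NVTETensor workspace, cudaStream_t stream) {
  // Passing null handles for the optional bias and padded cumulative-length
  // tensors is an assumption about the API, not something this commit documents.
  nvte_fused_attn_fwd_kvpacked(
      Q, KV, /*Bias=*/nullptr, S, O, aux_ctx_tensors, cu_seqlens_q, cu_seqlens_kv,
      /*cu_seqlens_q_padded=*/nullptr, /*cu_seqlens_kv_padded=*/nullptr,
      page_table_k, page_table_v, rng_state, max_seqlen_q, max_seqlen_kv,
      /*is_training=*/false, attn_scale, /*dropout=*/0.0f, NVTE_BSHD_BSHD_BSHD,
      NVTE_NO_BIAS, NVTE_PADDING_MASK, /*window_size_left=*/-1,
      /*window_size_right=*/-1, workspace, stream);
}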

/*! \brief Compute the backward of the dot product attention with packed KV input.
*
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/attention.py
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.

"""Attention."""
"""Attention"""
import collections
from contextlib import nullcontext
from importlib.metadata import version as get_pkg_version
@@ -363,7 +363,7 @@ def __eq__(self, other):
}


__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
__all__ = ["DotProductAttention", "MultiheadAttention"]


def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
26 changes: 0 additions & 26 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -276,32 +276,6 @@ def fused_attn_fwd(

# execute kernel

# print(max_seqlen_q,
# max_seqlen_kv,
# is_training,
# attn_scale,
# dropout,
# fast_zero_fill,
# QKVLayout[qkv_layout],
# AttnBiasType[attn_bias_type],
# AttnMaskType[attn_mask_type],
# window_size,
# cu_seqlens_q,
# cu_seqlens_kv,
# q.shape,
# k.shape,
# v.shape,
# fake_dtype,
# cu_seqlens_q_padded,
# cu_seqlens_kv_padded,
# page_table_k,
# page_table_v,
# s_quantizer,
# o_quantizer,
# attn_bias,
# rng_gen,
# rng_elts_per_thread,
# )
output_tensors = tex.fused_attn_fwd(
max_seqlen_q,
max_seqlen_kv,
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -14,6 +14,8 @@
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
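
The two new c10 Float8 headers bring in the FP8 element types used on the PyTorch side. A minimal standalone sketch of what they provide (round-tripping through float); how common.h actually uses them is not shown in this diff:

#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <iostream>

// Sketch only: the c10 FP8 scalar types construct from and convert to float.
int main() {
  c10::Float8_e4m3fn a(1.75f);  // 4 exponent bits, 3 mantissa bits, finite-only (no inf)
  c10::Float8_e5m2 b(1.75f);    // 5 exponent bits, 2 mantissa bits
  std::cout << static_cast<float>(a) << " " << static_cast<float>(b) << std::endl;
  return 0;
}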
22 changes: 10 additions & 12 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -34,18 +34,6 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
/***************************************************************************************************
* Attention
**************************************************************************************************/
void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor step_lens,
NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, int max_ctx_len,
int max_seq_len);

void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, int h_o,
int d_o, int b, int max_seq_len, bool is_output_right_aligned);

void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor step_lens,
torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k,
int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq,
bool is_non_paged);

NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
@@ -82,6 +70,16 @@ std::vector<py::object> fused_attn_bwd(
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
int h_q, int d_q, int b, int max_ctx_len, int max_seq_len);
void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o,
int d_o, int b, int max_seq_len, bool is_output_right_aligned);
void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int h_kv, int d_k,
int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq,
bool is_non_paged);
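A rough sketch of how the relocated cache helper might be driven from C++, assuming bf16 K/V tensors, int32 cumulative-length tensors of length b + 1, one new token per sequence, and a non-paged bshd cache; the shapes, dtypes, include path, and the meaning of the page-table argument in the non-paged case are all illustrative assumptions, not taken from this commit:

#include <torch/torch.h>
#include <transformer_engine/fused_attn.h>
#include "extensions.h"  // declares copy_to_kv_cache (include path assumed)

// Sketch only: append one decoding step of K/V into a non-paged bshd cache.
void append_step_to_cache(int b, int h_kv, int d_k, int d_v,
                          int max_ctx_len, int max_seq_len) {
  auto opts_kv = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
  auto opts_len = torch::dtype(torch::kInt32).device(torch::kCUDA);

  // One new token per sequence this step (shapes are assumptions).
  torch::Tensor new_k = torch::zeros({b, 1, h_kv, d_k}, opts_kv);
  torch::Tensor new_v = torch::zeros({b, 1, h_kv, d_v}, opts_kv);
  torch::Tensor k_cache = torch::zeros({b, max_seq_len, h_kv, d_k}, opts_kv);
  torch::Tensor v_cache = torch::zeros({b, max_seq_len, h_kv, d_v}, opts_kv);

  // Non-paged cache: a placeholder page table is still passed in this sketch.
  torch::Tensor page_table = torch::zeros({b, 1}, opts_len);
  // Cumulative new and cached lengths per sequence (assumed to be length b + 1).
  torch::Tensor cu_new_lens = torch::arange(0, b + 1, opts_len);
  torch::Tensor cu_cached_lens = torch::arange(0, b + 1, opts_len);

  copy_to_kv_cache(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
                   cu_cached_lens, NVTE_BSHD, h_kv, d_k, d_v, b, max_ctx_len,
                   max_seq_len, /*max_pages_per_seq=*/1, /*is_non_paged=*/true);
}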

/***************************************************************************************************
* GEMM
**************************************************************************************************/