[release/2.6] Integrate CK prebuilt library into Pytorch build #2031

Merged (16 commits, Apr 25, 2025)

7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -888,6 +888,13 @@ if(USE_ROCM)
endif()
endif()

# link CK library
if(USE_ROCM)
if(UNIX AND USE_CK_FLASH_ATTENTION)
include(cmake/External/ck_kernels.cmake)
endif()
endif()

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
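
Read together with the hunks below, this change sets up a three-way choice: generate CK kernels from source, reuse a preinstalled libck_kernels.so, or fetch a prebuilt package. What follows is a condensed, non-verbatim sketch of that control flow, using only the variables and targets this PR introduces:

# Condensed sketch of the flow wired up across the three CMake files in this PR (not the verbatim diff)
if(USE_ROCM AND UNIX AND USE_CK_FLASH_ATTENTION)
  # cmake/External/ck_kernels.cmake either binds __ck_kernels_lib to a prebuilt
  # libck_kernels.so or sets CK_KERNELS_INSTALL_FROM_SOURCE when no prebuilt source is given.
  include(cmake/External/ck_kernels.cmake)
endif()

# aten/src/ATen/CMakeLists.txt: generate CK kernel instances only on the from-source path.
if(USE_CK_FLASH_ATTENTION AND DEFINED CK_KERNELS_INSTALL_FROM_SOURCE)
  add_subdirectory(native/transformers/hip/flash_attn/ck)
endif()

# caffe2/CMakeLists.txt: otherwise link the prebuilt library into torch_hip.
if(USE_CK_FLASH_ATTENTION AND NOT CK_KERNELS_INSTALL_FROM_SOURCE)
  target_link_libraries(torch_hip PRIVATE __ck_kernels_lib)
endif()
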
30 changes: 15 additions & 15 deletions aten/src/ATen/CMakeLists.txt
@@ -172,23 +172,23 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
if(USE_FLASH_ATTENTION)
if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
if(USE_CK_FLASH_ATTENTION STREQUAL "1")
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
endif()
if(USE_CK_FLASH_ATTENTION)
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
if(DEFINED CK_KERNELS_INSTALL_FROM_SOURCE)
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
endif()
endif()
# building CK kernels from source
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
endif() # end of CK_KERNELS_INSTALL_FROM_SOURCE
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
endif() # end of USE_CK_FLASH_ATTENTION
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
endif()
11 changes: 5 additions & 6 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -1137,6 +1137,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
at::ROCmFABackend::Ck) {

#if defined(USE_CK_FLASH_ATTENTION)
TORCH_WARN_ONCE("Using CK backend for Efficient Attention forward...");

std::optional<Tensor> out(res);
std::optional<Tensor> seqused_k = std::nullopt;
std::optional<Tensor> alibi_slopes = std::nullopt;
@@ -1171,12 +1173,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
TORCH_CHECK(false, "Attempting to use CK mem_eff_forward backend in a build that has not built CK");
#endif
} else { // use aotriton
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
" (gfx90a/gfx942/gfx1100/gfx1201)")
}
pytorch_flash::check_aotriton_gpu_arch(stream);

TORCH_WARN_ONCE("Using AOTriton backend for Efficient Attention forward...");

// AOTriton may accept aligned on logsumexp tensor in the future for better
// performance, but for now it requires compact logsumexp tensor, even if
12 changes: 6 additions & 6 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -413,6 +413,8 @@ _efficient_attention_backward(
if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck)
{
#if defined(USE_CK_FLASH_ATTENTION)
TORCH_WARN_ONCE("Using CK backend for Efficient Attention backward...");

const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float();
// Store grad_bias in optional
std::optional<at::Tensor> opt_grad_bias = grad_bias;
@@ -454,12 +456,10 @@ _efficient_attention_backward(
"ROCM does not support num_split_keys in _efficient_attention_forward");
TORCH_CHECK(!window_size.has_value(),
"ROCM does not support window_size in _efficient_attention_forward");
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
" (gfx90a/gfx942/gfx1100/gfx1201)")
}
pytorch_flash::check_aotriton_gpu_arch(stream);

TORCH_WARN_ONCE("Using AOTriton backend for Efficient Attention backward...");

const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
@@ -71,9 +71,7 @@

namespace pytorch_flash {

namespace {

void check_gpu_arch(hipStream_t stream) {
void check_aotriton_gpu_arch(hipStream_t stream) {
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
@@ -82,6 +80,8 @@ void check_gpu_arch(hipStream_t stream) {
}
}

namespace {

// We want to checkpoint and save the RNG state for backward if dropout
// We get the default generator and return the seed and offset which will
// be used in the backward function
@@ -133,7 +133,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
const bool return_softmax,
std::optional<at::Generator> gen_) {
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
check_aotriton_gpu_arch(stream);

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -275,7 +275,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot

at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
check_aotriton_gpu_arch(stream);

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -444,7 +444,7 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
check_aotriton_gpu_arch(stream);

bool is_dropout = p_dropout > 0.0;

@@ -631,7 +631,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
check_aotriton_gpu_arch(stream);


bool is_dropout = p_dropout > 0.0;

18 changes: 17 additions & 1 deletion aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
@@ -8,6 +8,9 @@

namespace pytorch_flash {

// check AOTriton GPU support
void check_aotriton_gpu_arch(hipStream_t stream);

// AOTriton Implementation
TORCH_API
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -228,6 +231,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
TORCH_WARN_ONCE("Using CK backend for Flash Attention forward...");

std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
@@ -243,6 +248,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
gen_,
dummy_attn_bias); // Not used in flash attention
} else {
TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention forward...");
return mha_fwd_aot(q,
k,
v,
Expand All @@ -258,7 +264,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head

}
#else
return mha_fwd_aot(q,
TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention forward...");
return mha_fwd_aot(q,
k,
v,
out_,
@@ -296,6 +303,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
TORCH_WARN_ONCE("Using CK backend for Flash Attention varlen forward...");
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_varlen_fwd_ck(
q,
@@ -317,6 +325,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
gen_,
dummy_attn_bias); // Not used in flash attention
} else {
TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen forward...");
return mha_varlen_fwd_aot(q,
k,
v,
@@ -338,6 +347,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
gen_);
}
#else
TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen forward...");
return mha_varlen_fwd_aot(q,
k,
v,
@@ -384,6 +394,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
TORCH_WARN_ONCE("Using CK backend for Flash Attention backward...");
std::optional<at::Tensor> non_null_dbias = std::nullopt;
auto[dQuery,
dKey,
@@ -413,6 +424,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
// for FA return [dQ, dV, dK, dSoftmax]
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
} else {
TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention backward...");
return mha_bwd_aot(dout,
q,
k,
@@ -437,6 +449,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
at::ROCmFABackend::Ck) {
TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend...");
}
TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention backward...");
return mha_bwd_aot(
dout,
q,
@@ -487,6 +500,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
TORCH_WARN_ONCE("Using CK backend for Flash Attention varlen backward...");
std::optional<at::Tensor> non_null_dbias = std::nullopt;
auto[dQuery,
dKey,
@@ -521,6 +535,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// for FA return [dQ, dV, dK, dSoftmax]
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
} else {
TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen backward...");
return mha_varlen_bwd_aot(dout,
q,
k,
@@ -546,6 +561,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
philox_offset);
}
#else
TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen backward...");
return mha_varlen_bwd_aot(dout,
q,
k,
4 changes: 4 additions & 0 deletions caffe2/CMakeLists.txt
@@ -929,6 +929,10 @@ if(USE_ROCM)
if(USE_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __caffe2_aotriton)
endif()
# link CK library if not building CK_kernels from source
if(USE_CK_FLASH_ATTENTION AND NOT CK_KERNELS_INSTALL_FROM_SOURCE)
target_link_libraries(torch_hip PRIVATE __ck_kernels_lib)
endif()
set(CUDA_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_hip) # see cmake/public/utils.cmake
# TODO: Not totally sure if this is live or not
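
The __ck_kernels_lib target linked here is the INTERFACE library created in cmake/External/ck_kernels.cmake (added below); it owns no sources and only forwards the prebuilt shared object to its consumers. A minimal sketch of that wrapper pattern, with illustrative names that are not taken from the PR:

# Minimal INTERFACE-wrapper sketch (library and target names are illustrative only)
add_library(prebuilt_kernels INTERFACE)
target_link_libraries(prebuilt_kernels INTERFACE ${CMAKE_CURRENT_LIST_DIR}/lib/libprebuilt_kernels.so)

add_library(consumer SHARED consumer.cpp)
# consumer inherits the .so on its link line through the INTERFACE target
target_link_libraries(consumer PRIVATE prebuilt_kernels)
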
63 changes: 63 additions & 0 deletions cmake/External/ck_kernels.cmake
@@ -0,0 +1,63 @@
#
# create INTERFACE target for CK library
#
if(NOT __ck_kernels_included)
set(__ck_kernels_included TRUE)

set(ck_kernels_install_dir "${PROJECT_SOURCE_DIR}/torch/lib")

set(__ck_kernels_version 0.1)

# create INTERFACE target
add_library(__ck_kernels_lib INTERFACE)

if(DEFINED ENV{CK_KERNELS_INSTALLED_PREFIX})
# Copy .so from $ENV{CK_KERNELS_INSTALLED_PREFIX} into ${ck_kernels_install_dir}
install(DIRECTORY
$ENV{CK_KERNELS_INSTALLED_PREFIX}/
DESTINATION ${ck_kernels_install_dir}
)
set(ck_kernels_install_path "$ENV{CK_KERNELS_INSTALLED_PREFIX}/libck_kernels.so")
# specify path to CK library
target_link_libraries(__ck_kernels_lib INTERFACE ${ck_kernels_install_path})
message(STATUS "Using Preinstalled CK_kernels from $ENV{CK_KERNELS_INSTALLED_PREFIX}; installed at ${ck_kernels_install_dir}")
elseif(DEFINED ENV{CK_KERNELS_PACKAGE_BASE_URL})
# get CK commit hash
execute_process(
COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel
RESULT_VARIABLE result
OUTPUT_VARIABLE submodule_status
ERROR_VARIABLE submodule_status_error
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(result EQUAL 0)
string(REGEX REPLACE "^[ \t]" "" submodule_status ${submodule_status})
# extract first 8 characters of the commit hash
string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash)
else()
message(FATAL_ERROR "Failed to get submodule status for composable_kernel.")
endif()

set(ck_kernels_package_full_url "$ENV{CK_KERNELS_PACKAGE_BASE_URL}/torch_ck_gen_lib/ck_${ck_commit_hash}/rocm_${ROCM_VERSION_DEV}/libck_kernels.tar.gz")
set(ck_kernels_install_path "${ck_kernels_install_dir}/libck_kernels.so")

ExternalProject_Add(ck_kernels_external
URL "${ck_kernels_package_full_url}"
# URL_HASH
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball"
"${ck_kernels_install_dir}"
INSTALL_BYPRODUCTS "${ck_kernels_install_path}"
)
add_dependencies(__ck_kernels_lib ck_kernels_external)
message(STATUS "Using CK_kernels from pre-compiled binary ${ck_kernels_package_full_url}; installed at ${ck_kernels_install_dir}")
# specify path to CK library
target_link_libraries(__ck_kernels_lib INTERFACE ${ck_kernels_install_path})
else()
set(CK_KERNELS_INSTALL_FROM_SOURCE TRUE)
endif() # DEFINED ENV{CK_KERNELS_INSTALLED_PREFIX}

endif() # __ck_kernels_included
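
A sketch of how the new module might be driven from a minimal top-level CMakeLists.txt. This is an assumption-laden illustration, not part of the PR: it presumes the ExternalProject module has already been made available (ck_kernels.cmake calls ExternalProject_Add but does not include it itself), that ROCM_VERSION_DEV is set the way PyTorch's LoadHIP.cmake normally sets it, and that the source tree provides torch/lib and third_party/composable_kernel where the module expects them.

# Hypothetical standalone driver for cmake/External/ck_kernels.cmake (not from the PR)
cmake_minimum_required(VERSION 3.18)
project(ck_kernels_demo LANGUAGES CXX)

include(ExternalProject)            # needed for the ExternalProject_Add path in ck_kernels.cmake
set(ROCM_VERSION_DEV "6.3")         # assumption; normally derived from the ROCm install

# Mode selection happens through the environment variables the module reads:
#   CK_KERNELS_INSTALLED_PREFIX  - directory that already contains libck_kernels.so
#   CK_KERNELS_PACKAGE_BASE_URL  - base URL hosting torch_ck_gen_lib/.../libck_kernels.tar.gz
include(cmake/External/ck_kernels.cmake)

if(NOT CK_KERNELS_INSTALL_FROM_SOURCE)
  add_executable(demo main.cpp)
  target_link_libraries(demo PRIVATE __ck_kernels_lib)   # pulls in the prebuilt libck_kernels.so
endif()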