Skip to content

Merge release/2.6_ck_2 to release/2.6 #2012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,13 @@ cmake_dependent_option(
"USE_CUDA OR USE_ROCM;NOT MSVC"
OFF)

# USE_CK_FLASH_ATTENTION: build the composable_kernel (CK) flash-attention
# kernels. Defaults ON, but cmake_dependent_option only honors that default
# while the USE_FLASH_ATTENTION condition holds; otherwise the option is
# forced to OFF and hidden from the cache.
cmake_dependent_option(
USE_CK_FLASH_ATTENTION
"Whether to build the CK flash_attention kernel. Will be enabled if USE_FLASH_ATTENTION is enabled."
ON
"USE_FLASH_ATTENTION"
OFF)

# We are currently not using alibi attention for Flash, so we disable this
# feature by default. We don't currently document this feature because we
# don't suspect users building from source will need it.
Expand All @@ -888,6 +895,13 @@ if(USE_ROCM)
endif()
endif()

# CK shared lib linkage: pull in cmake/External/ck.cmake, which downloads the
# prebuilt CK kernels library and defines the __ck_lib INTERFACE target, when
# building for ROCm on a Unix host with CK flash attention enabled.
if(USE_ROCM AND UNIX AND USE_CK_FLASH_ATTENTION)
  include(cmake/External/ck.cmake)
endif()

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
Expand Down
28 changes: 13 additions & 15 deletions aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -172,22 +172,20 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
if(USE_FLASH_ATTENTION)
if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
if(USE_CK_FLASH_ATTENTION STREQUAL "1")
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
endif()
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
# When CK flash attention is enabled, pick up the pre-generated CK kernel
# sources and append them to the HIP source list for this target.
if(USE_CK_FLASH_ATTENTION)
# PYTORCH_ROCM_ARCH with more than one gfx arch multiplies CK kernel
# instantiations; warn so users can narrow the arch list for faster builds.
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
endif()
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
# NOTE(review): the status message below still says "Generating", but in-tree
# generation is intentionally disabled here — this PR switches to a prebuilt
# CK library downloaded by cmake/External/ck.cmake. The *.hip files globbed
# below are presumably pre-generated/checked in — confirm against the PR.
message(STATUS "Generating CK kernel instances...")
# disable building CK files
# add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
endif()
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
Expand Down
63 changes: 0 additions & 63 deletions aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt

This file was deleted.

4 changes: 4 additions & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,10 @@ if(USE_ROCM)
if(USE_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __caffe2_aotriton)
endif()
# Link the prebuilt CK kernels library into torch_hip. __ck_lib is the
# INTERFACE target created in cmake/External/ck.cmake wrapping the downloaded
# libck_kernels.so; PRIVATE because it is an implementation detail of
# torch_hip, not part of its public link interface.
if(USE_CK_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __ck_lib)
endif()
set(CUDA_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_hip) # see cmake/public/utils.cmake
# TODO: Not totally sure if this is live or not
Expand Down
43 changes: 43 additions & 0 deletions cmake/External/ck.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#
# create INTERFACE target for CK library
#
# Downloads a prebuilt libck_kernels.so that matches the composable_kernel
# submodule commit and the ROCm version detected by LoadHIP.cmake, then wraps
# it in the INTERFACE target __ck_lib for torch_hip to link against.

# get CK commit hash of the composable_kernel submodule (configure time)
execute_process(
  COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel
  RESULT_VARIABLE result
  OUTPUT_VARIABLE submodule_status
  ERROR_VARIABLE submodule_status_error
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(result EQUAL 0)
  # `git submodule status` prefixes the hash with ' ', '+', or '-' depending
  # on checkout state; strip every such prefix character (the original
  # "^[ \t]" only removed a single space/tab and would leave a '+'/'-' in
  # front of the hash, corrupting the 8-char slice below).
  string(REGEX REPLACE "^[ \t+-]+" "" submodule_status "${submodule_status}")
  # first 8 characters of the commit hash identify the prebuilt artifact
  string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash)
else()
  message(FATAL_ERROR "Failed to get submodule status for composable_kernel: ${submodule_status_error}")
endif()

# get ROCm version (ROCM_VERSION_DEV) from LoadHIP.cmake
include(${CMAKE_SOURCE_DIR}/cmake/public/LoadHIP.cmake)

# full path for CK library on compute-artifactory.amd.com
set(url "https://compute-artifactory.amd.com/artifactory/rocm-generic-local")
set(ck_lib_full_path "${url}/torch_ck_gen_lib/ck_${ck_commit_hash}/rocm_${ROCM_VERSION_DEV}/libck_kernels.so")

# NOTE(review): this writes into the source tree (torch/lib). That appears
# deliberate here — torch/lib is what gets packaged into the wheel — but it
# does mean `cmake` mutates the checkout; confirm this is intended.
set(destination "${CMAKE_SOURCE_DIR}/torch/lib/libck_kernels.so")

# download CK library. file(DOWNLOAD) reports its outcome via the STATUS
# keyword (a two-element "code;message" list, code 0 on success) —
# RESULT_VARIABLE is not a valid file(DOWNLOAD) option.
file(DOWNLOAD ${ck_lib_full_path} ${destination} SHOW_PROGRESS STATUS download_status)
list(GET download_status 0 download_code)
list(GET download_status 1 download_message)
if(download_code EQUAL 0)
  message(STATUS "Downloaded CK library successfully.")
else()
  # report the actual URL and the curl-style error message (the original
  # interpolated the undefined variable SOURCE_URL here)
  message(FATAL_ERROR "Failed to download the CK library from ${ck_lib_full_path}: ${download_message}")
endif()

# create INTERFACE target carrying the prebuilt library as a usage requirement
add_library(__ck_lib INTERFACE)

# consumers of __ck_lib link the downloaded shared library
target_link_libraries(__ck_lib INTERFACE ${destination})