
Download CK library from compute-artifactory and link to PyTorch #2007

Closed
wants to merge 1 commit into from
14 changes: 14 additions & 0 deletions CMakeLists.txt
@@ -865,6 +865,13 @@ cmake_dependent_option(
"USE_CUDA OR USE_ROCM;NOT MSVC"
OFF)

cmake_dependent_option(
USE_CK_FLASH_ATTENTION
"Whether to build the CK flash_attention kernel. Will be enabled if USE_FLASH_ATTENTION is enabled."
ON
"USE_FLASH_ATTENTION"
OFF)

# We are currently not using alibi attention for Flash, so we disable this
# feature by default. We don't currently document this feature because we don't
# suspect users building from source will need it.
@@ -888,6 +895,13 @@ if(USE_ROCM)
endif()
endif()

# CK shared lib linkage
if(USE_ROCM)
if(UNIX AND (USE_CK_FLASH_ATTENTION))
include(cmake/External/ck.cmake)
endif()
endif()

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
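For context, USE_CK_FLASH_ATTENTION is a dependent option that only takes effect when USE_FLASH_ATTENTION is itself enabled, and ck.cmake is only pulled in for UNIX ROCm builds. A minimal sketch of driving a source build with these switches is below; it assumes the standard PyTorch convention that USE_*-prefixed environment variables are forwarded to CMake by setup.py, which is not part of this PR.

# Hypothetical helper: configure a ROCm source build with the CK flash-attention
# path enabled. Assumes setup.py forwards USE_*-prefixed environment variables
# to CMake (standard PyTorch build flow; not introduced by this PR).
import os
import subprocess

def build_with_ck_flash_attention(pytorch_dir: str) -> None:
    env = os.environ.copy()
    env.update({
        "USE_ROCM": "1",
        "USE_FLASH_ATTENTION": "1",      # USE_CK_FLASH_ATTENTION depends on this
        "USE_CK_FLASH_ATTENTION": "1",   # new option introduced by this PR
    })
    subprocess.check_call(["python", "setup.py", "develop"], cwd=pytorch_dir, env=env)

if __name__ == "__main__":
    build_with_ck_flash_attention(".")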
3 changes: 2 additions & 1 deletion aten/src/ATen/CMakeLists.txt
@@ -184,7 +184,8 @@ if(USE_FLASH_ATTENTION)
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
# Disable building the CK kernel instances from source; the prebuilt CK library is downloaded and linked instead.
# add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
endif()
63 changes: 0 additions & 63 deletions aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt

This file was deleted.

1 change: 1 addition & 0 deletions benchmarks/transformer/sdpa.py
@@ -172,6 +172,7 @@ def print_results(experiments: List[Experiment]):


def main():
torch.backends.cuda.preferred_rocm_fa_library("ck")
seed = 123
torch.manual_seed(seed)
results = []
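The one-line benchmark change forces the CK backend for ROCm flash attention. The sketch below shows the same selection in a standalone script; it assumes a ROCm build of PyTorch with USE_CK_FLASH_ATTENTION enabled, and the preferred_rocm_fa_library call is taken from the diff above.

# Minimal sketch: route scaled_dot_product_attention through the CK
# flash-attention backend on ROCm, mirroring the benchmark change above.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

torch.backends.cuda.preferred_rocm_fa_library("ck")  # same call as in sdpa.py

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)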
4 changes: 4 additions & 0 deletions caffe2/CMakeLists.txt
@@ -929,6 +929,10 @@ if(USE_ROCM)
if(USE_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __caffe2_aotriton)
endif()
# link CK library
if(USE_CK_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __ck_lib)
endif()
set(CUDA_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_hip) # see cmake/public/utils.cmake
# TODO: Not totally sure if this is live or not
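The new target_link_libraries line makes torch_hip depend on the downloaded libck_kernels.so, which ck.cmake places under torch/lib. The sketch below is one way to sanity-check the artifact in a build where the library ends up next to the other torch libraries; the file name and location come from cmake/External/ck.cmake below.

# Sketch: verify that the prebuilt CK kernel library is present next to the
# other torch libraries and can be loaded. Path and file name follow the
# destination set in cmake/External/ck.cmake.
import ctypes
import os
import torch

lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib")
ck_lib = os.path.join(lib_dir, "libck_kernels.so")

if os.path.exists(ck_lib):
    ctypes.CDLL(ck_lib)  # raises OSError if the library cannot be loaded
    print(f"CK library found and loadable: {ck_lib}")
else:
    print(f"CK library not present: {ck_lib}")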
44 changes: 44 additions & 0 deletions cmake/External/ck.cmake
@@ -0,0 +1,44 @@
#
# create INTERFACE target for CK library
#

# get CK commit hash
execute_process(
COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel
RESULT_VARIABLE result
OUTPUT_VARIABLE submodule_status
ERROR_VARIABLE submodule_status_error
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(result EQUAL 0)
# strip the leading status character(s) reported by `git submodule status`
string(REGEX REPLACE "^[^0-9a-f]+" "" submodule_status ${submodule_status})
# extract first 8 characters of the commit hash
string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash)
else()
message(FATAL_ERROR "Failed to get submodule status for composable_kernel.")
endif()

# get ROCm version from LoadHIP.cmake
include(${CMAKE_SOURCE_DIR}/cmake/public/LoadHIP.cmake)

# full path for CK library on compute-artifactory.amd.com
set(url "https://compute-artifactory.amd.com/artifactory/rocm-generic-local")
set(ck_lib_full_path "${url}/torch_ck_gen_lib/ck_${ck_commit_hash}/rocm_${ROCM_VERSION_DEV}/libck_kernels.so")

# set destination
set(destination "${CMAKE_SOURCE_DIR}/torch/lib/libck_kernels.so")

# download CK library
file(DOWNLOAD ${ck_lib_full_path} ${destination} SHOW_PROGRESS STATUS download_status)
# file(DOWNLOAD) reports a two-element list: a numeric code (0 on success) and an error string
list(GET download_status 0 download_status_code)
if(download_status_code EQUAL 0)
message(STATUS "Downloaded CK library successfully.")
else()
message(FATAL_ERROR "Failed to download the CK library from ${ck_lib_full_path}.")
endif()

# create INTERFACE target
add_library(__ck_lib INTERFACE)

# specify path to CK library
target_link_libraries(__ck_lib INTERFACE ${destination})
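
For reference, the URL that ck.cmake assembles can be reproduced outside of CMake. The sketch below mirrors the same logic; the host and path layout come straight from the snippet above, the ROCm version string is whatever LoadHIP.cmake reports as ROCM_VERSION_DEV, and the repository is AMD-internal, so the download only works from a network that can reach compute-artifactory.amd.com.

# Sketch: rebuild the artifact URL used by cmake/External/ck.cmake from the
# composable_kernel submodule hash and a given ROCm version string.
import subprocess

def ck_artifact_url(pytorch_dir: str, rocm_version: str) -> str:
    status = subprocess.check_output(
        ["git", "-C", f"{pytorch_dir}/third_party", "submodule", "status", "composable_kernel"],
        text=True,
    ).strip()
    commit_hash = status.lstrip(" +-U")[:8]  # drop the status prefix, keep the first 8 hex chars
    base = "https://compute-artifactory.amd.com/artifactory/rocm-generic-local"
    return f"{base}/torch_ck_gen_lib/ck_{commit_hash}/rocm_{rocm_version}/libck_kernels.so"

# Example (the version string must match ROCM_VERSION_DEV from LoadHIP.cmake):
# print(ck_artifact_url(".", "6.2.0"))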