Skip to content
9 changes: 1 addition & 8 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1373,13 +1373,6 @@ if(USE_ROCM)
set(ROCM_SOURCE_DIR "/opt/rocm")
endif()
message(INFO "caffe2 ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR}")
target_include_directories(torch_hip PRIVATE
${ROCM_SOURCE_DIR}/include
${ROCM_SOURCE_DIR}/hcc/include
${ROCM_SOURCE_DIR}/rocblas/include
${ROCM_SOURCE_DIR}/hipsparse/include
${ROCM_SOURCE_DIR}/include/rccl/
)
if(USE_FLASH_ATTENTION)
target_compile_definitions(torch_hip PRIVATE
USE_FLASH_ATTENTION
Expand Down Expand Up @@ -1713,7 +1706,7 @@ if(USE_ROCM)
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})

# Since PyTorch files contain HIP headers, this is also needed to capture the includes.
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE})
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS})
target_include_directories(torch_hip INTERFACE $<INSTALL_INTERFACE:include>)
endif()

Expand Down
23 changes: 17 additions & 6 deletions cmake/public/LoadHIP.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ else()
endif()
endif()

if(NOT DEFINED ENV{ROCM_INCLUDE_DIRS})
set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include)
if(NOT DEFINED ENV{ROCM_INCLUDE_DIR})
set(ROCM_INCLUDE_DIR ${ROCM_PATH}/include)
else()
set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS})
set(ROCM_INCLUDE_DIR $ENV{ROCM_INCLUDE_DIR})
endif()

# MAGMA_HOME
Expand Down Expand Up @@ -72,6 +72,7 @@ list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
macro(find_package_and_print_version PACKAGE_NAME)
find_package("${PACKAGE_NAME}" ${ARGN})
message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR})
endmacro()

# Find the HIP Package
Expand All @@ -82,9 +83,16 @@ find_package_and_print_version(HIP 1.0 MODULE)
if(HIP_FOUND)
set(PYTORCH_FOUND_HIP TRUE)

find_package_and_print_version(hip REQUIRED CONFIG)
# Find ROCM version for checks. UNIX filename is rocm_version.h, Windows is hip_version.h
find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h hip_version.h
HINTS ${ROCM_INCLUDE_DIRS}/rocm-core ${ROCM_INCLUDE_DIRS}/hip /usr/include)
if(UNIX)
find_package_and_print_version(rocm-core REQUIRED CONFIG)
find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h
HINTS ${rocm_core_INCLUDE_DIR}/rocm-core)
else() # Win32
find_file(ROCM_VERSION_HEADER_PATH NAMES hip_version.h
HINTS ${hip_INCLUDE_DIR}/hip)
endif()
get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME)

if(EXISTS ${ROCM_VERSION_HEADER_PATH})
Expand Down Expand Up @@ -141,7 +149,6 @@ if(HIP_FOUND)
# Find ROCM components using Config mode
# These components will be searced for recursively in ${ROCM_PATH}
message("\n***** Library versions from cmake find_package *****\n")
find_package_and_print_version(hip REQUIRED CONFIG)
find_package_and_print_version(amd_comgr REQUIRED)
find_package_and_print_version(rocrand REQUIRED)
find_package_and_print_version(hiprand REQUIRED)
Expand All @@ -168,7 +175,11 @@ if(HIP_FOUND)
if(UNIX)
find_package_and_print_version(rccl)
find_package_and_print_version(hsa-runtime64 REQUIRED)
endif()

list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS)

if(UNIX)
# roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)

Expand Down