@@ -1373,12 +1373,62 @@ if(USE_ROCM)
1373
1373
set (ROCM_SOURCE_DIR "/opt/rocm" )
1374
1374
endif ()
1375
1375
message (INFO "caffe2 ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR} " )
1376
+ find_package (rocthrust )
1377
+ if (rocthrust_FOUND )
1378
+ message (STATUS "rocthrust found" )
1379
+ else () #If rocthrust not found
1380
+ message (FATAL_ERROR "rocthrust not found !!! Install rocthrust to proceed ..." )
1381
+ endif (rocthrust_FOUND )
1382
+
1383
+ find_package (aotriton )
1384
+ if (aotriton_FOUND )
1385
+ message (STATUS "aotriton found" )
1386
+ set (AOTRITON_INCLUDE_DIR ${INTERFACE_INCLUDE_DIRECTORIES} )
1387
+ else () #If aotriton not found
1388
+ message (FATAL_ERROR "aotriton not found !!! Install aotriton to proceed ..." )
1389
+ endif (aotriton_FOUND )
1390
+
1391
+ find_package (rocprim )
1392
+ if (rocprim_FOUND )
1393
+ message (STATUS "rocprim found" )
1394
+ else () #If rocprim not found
1395
+ message (FATAL_ERROR "rocprim not found !!! Install rocprim to proceed ..." )
1396
+ endif (rocprim_FOUND )
1397
+
1398
+ find_package (hipcub )
1399
+ if (hipcub_FOUND )
1400
+ message (STATUS "hipcub found" )
1401
+ else () #If hipcub not found
1402
+ message (FATAL_ERROR "hipcub not found !!! Install hipcub to proceed ..." )
1403
+ endif (hipcub_FOUND )
1404
+
1405
+ find_package (rocrand )
1406
+ if (rocrand_FOUND )
1407
+ message (STATUS "rocrand found" )
1408
+ else () #If rocrand not found
1409
+ message (FATAL_ERROR "rocrand not found !!! Install rocrand to proceed ..." )
1410
+ endif (rocrand_FOUND )
1411
+
1412
+ find_package (composable_kernel )
1413
+ if (composable_kernel_FOUND )
1414
+ message (STATUS "composable-kernel found" )
1415
+ set (CK_INCLUDE_DIR ${INTERFACE_INCLUDE_DIRECTORIES} )
1416
+ else () #If composable-kernel not found
1417
+ message (FATAL_ERROR "composable-kernel not found !!! Install composable-kernel to proceed ..." )
1418
+ endif (composable-kernel_FOUND )
1419
+
1376
1420
target_include_directories (torch_hip PRIVATE
1377
1421
${ROCM_SOURCE_DIR} /include
1378
1422
${ROCM_SOURCE_DIR} /hcc/include
1379
1423
${ROCM_SOURCE_DIR} /rocblas/include
1380
1424
${ROCM_SOURCE_DIR} /hipsparse/include
1381
1425
${ROCM_SOURCE_DIR} /include/rccl/
1426
+ ${AOTRITON_INCLUDE_DIR}
1427
+ ${ROCTHRUST_INCLUDE_DIR}
1428
+ ${ROCPRIM_INCLUDE_DIR}
1429
+ ${ROCRAND_INCLUDE_DIR}
1430
+ ${HIPCUB_INCLUDE_DIR}
1431
+ ${CK_INCLUDE_DIR}
1382
1432
)
1383
1433
if (USE_FLASH_ATTENTION )
1384
1434
target_compile_definitions (torch_hip PRIVATE
0 commit comments