File tree 1 file changed +15
-0
lines changed
1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -69,6 +69,21 @@ if(INTERN_BUILD_ATEN_OPS)
69
69
70
70
file (GLOB_RECURSE all_python "${CMAKE_CURRENT_LIST_DIR} /../torchgen/*.py" )
71
71
72
+ # RowwiseScaled.cu requires sm90a flags
73
+ if (USE_CUDA)
74
+ set (ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR} /../aten/src/ATen/native/cuda/RowwiseScaledMM.cu" )
75
+
76
+ # Get existing arch flags
77
+ torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)
78
+
79
+ # Check NVCC version and existing arch flags
80
+ if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND
81
+ EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*" )
82
+ set_source_files_properties (${ROWWISE_SCALED_MM_FILE}
83
+ PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a" )
84
+ endif ()
85
+ endif ()
86
+
72
87
set (GEN_ROCM_FLAG)
73
88
if (USE_ROCM)
74
89
set (GEN_ROCM_FLAG --rocm)
You can’t perform that action at this time.
0 commit comments