Skip to content

Commit 8b00b06

Browse files
committed
Add cross-platform support for Marlin kernel attribute setting
Modify the Marlin kernel to support both CUDA and ROCm platforms by: - Introducing platform-specific function attribute setting - Using conditional compilation with #ifdef __HIP_PLATFORM_AMD__ - Refactoring kernel attribute configuration for better cross-platform compatibility
1 parent 04014e7 commit 8b00b06

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -861,10 +861,28 @@ __global__ void Marlin_24(
861861
thread_n_blocks == THREAD_N_BLOCKS && \
862862
thread_k_blocks == THREAD_K_BLOCKS && \
863863
group_blocks == GROUP_BLOCKS) { \
864-
cudaFuncSetAttribute( \
865-
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
866-
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
867-
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
864+
/* Define the kernel type */ \
865+
using KernelType = decltype(Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, \
866+
THREAD_M_BLOCKS, THREAD_K_BLOCKS, \
867+
STAGES, GROUP_BLOCKS>); \
868+
/* Get function pointer for the kernel */ \
869+
const void* kernel_ptr = reinterpret_cast<const void*>( \
870+
&Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
871+
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>); \
872+
\
873+
/* Set attribute based on platform */ \
874+
#ifdef __HIP_PLATFORM_AMD__ \
875+
hipFuncSetAttribute(kernel_ptr, \
876+
hipFuncAttributeMaxDynamicSharedMemorySize, \
877+
max_shared_mem); \
878+
#else \
879+
cudaFuncSetAttribute( \
880+
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
881+
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
882+
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
883+
#endif \
884+
\
885+
/* Launch kernel */ \
868886
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
869887
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
870888
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \

0 commit comments

Comments
 (0)