[wip] Rocm sparse fix #1868
base: main
Conversation
Update GPU architecture check to use gcnArchName and improve detection of gfx942 support
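A minimal C++/HIP sketch of the same idea as this check, assuming the HIP runtime is available; the PR itself performs the detection at the setup.py level through PyTorch's device properties, and the helper name here is illustrative.

```cuda
#include <hip/hip_runtime.h>
#include <cstring>

// Hypothetical helper: report whether the active device is gfx942
// (MI300-class) by matching the prefix of gcnArchName, which typically
// looks like "gfx942:sramecc+:xnack-".
bool device_is_gfx942(int device_id = 0) {
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, device_id) != hipSuccess) {
    return false;  // no usable ROCm device
  }
  return std::strncmp(props.gcnArchName, "gfx942", 6) == 0;
}
```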
Reorganize source file selection logic for CUDA and ROCm builds, improving conditional handling of GPU sources and CUTLASS kernels. Simplify the source file selection process and improve readability of the build configuration.
Modify CUTLASS kernel configuration to explicitly check for non-ROCm platforms when enabling support, ensuring more precise build configuration for different GPU environments.
Move source file collection logic to maintain consistent code organization and improve readability of the build configuration. No functional changes were made to the source file selection process.
Remove the `-t=0` flag from NVCC compilation options, which appears to be unnecessary. This simplifies the compilation configuration without impacting build behavior.
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1868
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 3 New Failures, 4 Unrelated Failures as of commit 479cc1d with merge base 34421b1:
- NEW FAILURES: the following jobs have failed.
- FLAKY: the following jobs failed but were likely due to flakiness present on trunk.
- BROKEN TRUNK: the following jobs failed but were already present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Add conditional compilation for ROCm platforms in the sparse Marlin matrix multiply accumulate (MMA) function. This ensures proper inline assembly implementation for both CUDA and ROCm environments, using platform-specific register and instruction handling.
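A minimal sketch of the platform split this commit describes, not the kernel's actual code: the same MMA entry point compiles to PTX inline assembly on CUDA and to an AMD-specific path under USE_ROCM.

```cuda
__device__ inline void mma_sp_dispatch_sketch(const uint32_t* a,
                                              const uint32_t* b, float* c) {
#if defined(USE_ROCM)
  // ROCm path: AMD-specific handling (MFMA builtins or GCN inline assembly);
  // register constraints and instruction names differ from PTX.
#else
  // CUDA path: PTX inline assembly for sparse tensor-core MMA, e.g.
  // asm volatile("mma.sp.sync.aligned. ..." : "=f"(c[0]) ... );
#endif
}
```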
Use __builtin_bit_cast to correctly convert float pairs to half-precision uint32_t values for AMD GPU platforms, ensuring proper type handling in the sparse Marlin matrix multiply accumulate (MMA) implementation.
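A hedged sketch of the conversion this commit describes, assuming the Clang/HIP toolchain (which provides __builtin_bit_cast): two floats are narrowed into a packed __half2 and the 32-bit pattern is reinterpreted as uint32_t. The helper name is illustrative.

```cuda
#if defined(USE_ROCM)
#include <hip/hip_fp16.h>
#else
#include <cuda_fp16.h>
#endif
#include <cstdint>

__device__ inline uint32_t pack_floats_to_half2_bits(float lo, float hi) {
  __half2 h = __floats2half2_rn(lo, hi);    // narrow both floats to half
  return __builtin_bit_cast(uint32_t, h);   // reinterpret the 4-byte pattern
}
```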
Update CUDA half-precision operations using __hsub2 and __hfma2 intrinsics to improve performance and precision in sparse matrix multiply-accumulate (MMA) computations.
Update AMD GPU implementation to use __hsub2 and __hmul2 intrinsics for improved performance and precision in half-precision sparse matrix multiply-accumulate computations.
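For illustration, a small standalone use of the packed-half intrinsics named in the two commits above; __hsub2, __hfma2, and __hmul2 operate on both half lanes at once and exist in both CUDA and HIP. The function is a sketch, not the kernel's code.

```cuda
#if defined(USE_ROCM)
#include <hip/hip_fp16.h>
#else
#include <cuda_fp16.h>
#endif

__device__ inline __half2 sub_then_fma(__half2 x, __half2 bias,
                                       __half2 scale, __half2 offset) {
  __half2 t = __hsub2(x, bias);       // elementwise (x - bias) on both lanes
  return __hfma2(t, scale, offset);   // elementwise t * scale + offset
}
```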
Update AMD GPU implementation to use __builtin_amdgcn_fmul_f32 instead of __builtin_amdgcn_fmul_legacy for more accurate float multiplication in the scale_floats function.
Include necessary ROCm-specific headers for HIP runtime and half-precision operations, with comments addressing potential compiler and architecture considerations for AMD GPU platforms.
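An illustrative header set for the ROCm path described above; the exact guards and includes used in the PR may differ.

```cuda
#if defined(USE_ROCM)
#include <hip/hip_runtime.h>  // HIP runtime and __device__/__global__ macros
#include <hip/hip_fp16.h>     // __half / __half2 intrinsics on AMD GPUs
#endif
```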
Replace __builtin_amdgcn_fmul_f32 with __ocml_fmul_f32 for more accurate and consistent float multiplication in the scale_floats function on AMD GPU platforms.
Replace __builtin_amdgcn_global_load_lds with inline assembly using a ds_load_b instruction for more precise and direct global-to-LDS (Local Data Share) transfers on MI300X AMD GPUs.
Replace __ocml_fmul_f32 with standard C++ multiplication for more readable and straightforward float scaling on AMD MI300X GPUs.
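The form these commits converge on is ordinary IEEE float multiplication; a trivial standalone sketch follows (the helper name is illustrative, not the kernel's scale_floats signature).

```cuda
__device__ inline void scale_floats_sketch(float* out, const float* in,
                                           float scale, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = in[i] * scale;  // plain C++ multiply; no OCML/amdgcn builtin
  }
}
```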
Update cudaFuncSetAttribute call to use reinterpret_cast for correct function pointer handling in the Marlin_24 CUDA kernel, ensuring proper dynamic shared memory configuration.
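A sketch of the pattern this commit refers to: the runtime's cudaFuncSetAttribute takes a const void*, so the kernel symbol goes through reinterpret_cast when raising the dynamic shared-memory limit. The kernel here is illustrative, not the PR's Marlin_24 kernel.

```cuda
#include <cuda_runtime.h>

__global__ void example_kernel(int* out) { out[0] = 0; }

void raise_dynamic_smem_limit(int max_bytes) {
  cudaFuncSetAttribute(reinterpret_cast<const void*>(&example_kernel),
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       max_bytes);
}
```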
Refactor cp_async4 functions for ROCm to use explicit ds_load instructions for 4-, 8-, and 16-byte transfers. Add a fallback mechanism using __builtin_memcpy for unsupported sizes, improving the precision and flexibility of global-to-LDS (Local Data Share) transfers on MI300X AMD GPUs.
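A rough sketch of the fallback path this commit mentions; only the __builtin_memcpy branch is shown, since the ds_load inline assembly is architecture-specific to the PR.

```cuda
// Plain 16-byte copy from global memory into an LDS (shared-memory) slot;
// synchronous, used when a specialized transfer path is unavailable.
__device__ inline void cp_async4_fallback(void* smem_dst,
                                          const void* global_src) {
  __builtin_memcpy(smem_dst, global_src, 16);
}
```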
Add missing closing braces in cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions to ensure proper code structure and prevent potential compilation issues in the ROCm sparse Marlin MMA implementation.
…functions Simplify ROCm global to LDS transfer by removing fallback __builtin_memcpy in cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions, reducing code complexity while maintaining the primary ds_load_b128 transfer mechanism.
…functions Simplify ROCm global to LDS transfer by removing the 16-byte ds_load_b128 instruction from cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions, further reducing code complexity and maintaining the core transfer mechanism.
…tion Replace global_load_dwordx4 with multiple ds_read_b32 instructions for better compatibility and support across different ROCm platforms. Modify ldsm4 and ldsm4_t functions to use more widely supported memory load techniques.
Update ldsm4_m device function to use separate ds_read_b32 instructions instead of a single ds_read_b64, improving compatibility and load behavior on ROCm platforms.
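A conceptual stand-in for the ldsm4 rework described in the two commits above: ROCm has no ldmatrix equivalent, so the fragment is filled with separate 32-bit shared-memory reads. This plain-C++ version only shows the data movement and ignores lane-dependent addressing; the PR expresses the loads with ds_read_b32 inline assembly.

```cuda
__device__ inline void ldsm4_rocm_sketch(uint32_t frag[4],
                                         const uint32_t* smem_ptr) {
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    frag[i] = smem_ptr[i];  // one 32-bit LDS read per fragment word
  }
}
```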
Modify the MFMA instruction assembly for AMD GPUs to use correct syntax and operand handling. Replace register constraints with vector register constraints and simplify the instruction format to improve compatibility and readability on ROCm platforms.
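For reference, a hedged sketch of an MFMA call on CDNA GPUs using the compiler builtin rather than raw v_mfma assembly; the tile shape, operand packing, and modifiers in the PR's kernel may differ, and this only compiles for MFMA-capable targets such as gfx90a/gfx942.

```cuda
#if defined(USE_ROCM)
#include <hip/hip_runtime.h>

typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
typedef float float4_t __attribute__((ext_vector_type(4)));

// 16x16x16 fp16 matrix-fma across one wavefront; the trailing immediates
// (cbsz, abid, blgp) select broadcast/blocking modes and stay at defaults.
__device__ inline float4_t mfma_16x16x16_f16(half4_t a, half4_t b,
                                             float4_t acc) {
  return __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, acc, 0, 0, 0);
}
#endif
```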
TL;DR: fix the sparse Marlin kernel for ROCm.
This pull request includes several updates to the `setup.py` file and modifications to CUDA- and ROCm-specific code in the `torchao/csrc/cuda/sparse_marlin` directory. The most important changes focus on improving compatibility with ROCm, updating assertions, and refining the build process for CUDA and ROCm extensions.
Updates to `setup.py`:
- Remove the `-t=0` flag from the `nvcc` compile arguments.
Modifications to CUDA- and ROCm-specific code:
- Update the `cp_async4_pred_zfill`, `cp_async4_pred`, and `cp_async4` functions to use the `LDS.G` instruction for global-to-LDS transfers on MI300X. [1] [2] [3]
- Update `mma.h` to support the ROCm architecture and improve performance on MI300X. [1] [2] [3] [4]
- Update `to_half4`, `dequant_4bit`, `dequant_8bit`, `scale`, and `scale_floats` to use appropriate ROCm intrinsics and improve compatibility. [1] [2] [3] [4] [5]