Skip to content

Commit 6f43e01

Browse files
committed
Simplify ROCm float multiplication in sparse Marlin MMA
Replace __ocml_fmul_f32 with standard C++ multiplication for more readable and straightforward float scaling on AMD MI300X GPUs.
1 parent dc53980 commit 6f43e01

File tree

1 file changed

+9
-9
lines changed
  • torchao/csrc/cuda/sparse_marlin

1 file changed

+9
-9
lines changed

torchao/csrc/cuda/sparse_marlin/mma.h

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -280,16 +280,16 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
280280
FragS& s0, float* c4, float* c5, float* c6,
281281
float* c7, FragS& s1) {
282282
#ifdef USE_ROCM
283-
// AMD implementation - fixed
284-
*c0 = __ocml_fmul_f32(*c0, __half2float(s0[0].x));
285-
*c1 = __ocml_fmul_f32(*c1, __half2float(s0[0].y));
286-
*c2 = __ocml_fmul_f32(*c2, __half2float(s0[1].x));
287-
*c3 = __ocml_fmul_f32(*c3, __half2float(s0[1].y));
283+
// AMD MI300X implementation
284+
*c0 = *c0 * __half2float(s0[0].x);
285+
*c1 = *c1 * __half2float(s0[0].y);
286+
*c2 = *c2 * __half2float(s0[1].x);
287+
*c3 = *c3 * __half2float(s0[1].y);
288288

289-
*c4 = __ocml_fmul_f32(*c4, __half2float(s1[0].x));
290-
*c5 = __ocml_fmul_f32(*c5, __half2float(s1[0].y));
291-
*c6 = __ocml_fmul_f32(*c6, __half2float(s1[1].x));
292-
*c7 = __ocml_fmul_f32(*c7, __half2float(s1[1].y));
289+
*c4 = *c4 * __half2float(s1[0].x);
290+
*c5 = *c5 * __half2float(s1[0].y);
291+
*c6 = *c6 * __half2float(s1[1].x);
292+
*c7 = *c7 * __half2float(s1[1].y);
293293
#else
294294
// NVIDIA implementation
295295
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));

0 commit comments

Comments
 (0)