@@ -280,16 +280,16 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
280
280
FragS& s0, float * c4, float * c5, float * c6,
281
281
float * c7, FragS& s1) {
282
282
#ifdef USE_ROCM
283
- // AMD implementation - fixed
284
- *c0 = __ocml_fmul_f32 ( *c0, __half2float (s0[0 ].x ) );
285
- *c1 = __ocml_fmul_f32 ( *c1, __half2float (s0[0 ].y ) );
286
- *c2 = __ocml_fmul_f32 ( *c2, __half2float (s0[1 ].x ) );
287
- *c3 = __ocml_fmul_f32 ( *c3, __half2float (s0[1 ].y ) );
283
+ // AMD MI300X implementation
284
+ *c0 = *c0 * __half2float (s0[0 ].x );
285
+ *c1 = *c1 * __half2float (s0[0 ].y );
286
+ *c2 = *c2 * __half2float (s0[1 ].x );
287
+ *c3 = *c3 * __half2float (s0[1 ].y );
288
288
289
- *c4 = __ocml_fmul_f32 ( *c4, __half2float (s1[0 ].x ) );
290
- *c5 = __ocml_fmul_f32 ( *c5, __half2float (s1[0 ].y ) );
291
- *c6 = __ocml_fmul_f32 ( *c6, __half2float (s1[1 ].x ) );
292
- *c7 = __ocml_fmul_f32 ( *c7, __half2float (s1[1 ].y ));
289
+ *c4 = *c4 * __half2float (s1[0 ].x );
290
+ *c5 = *c5 * __half2float (s1[0 ].y );
291
+ *c6 = *c6 * __half2float (s1[1 ].x );
292
+ *c7 = *c7 * __half2float (s1[1 ].y );
293
293
#else
294
294
// NVIDIA implementation
295
295
*c0 = __fmul_rn (*c0, __half2float (s0[0 ].x ));
0 commit comments