@@ -280,16 +280,16 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
280280 FragS& s0, float * c4, float * c5, float * c6,
281281 float * c7, FragS& s1) {
282282 #ifdef USE_ROCM
283- // AMD implementation - fixed
284- *c0 = __ocml_fmul_f32 ( *c0, __half2float (s0[0 ].x ) );
285- *c1 = __ocml_fmul_f32 ( *c1, __half2float (s0[0 ].y ) );
286- *c2 = __ocml_fmul_f32 ( *c2, __half2float (s0[1 ].x ) );
287- *c3 = __ocml_fmul_f32 ( *c3, __half2float (s0[1 ].y ) );
283+ // AMD MI300X implementation
284+ *c0 = *c0 * __half2float (s0[0 ].x );
285+ *c1 = *c1 * __half2float (s0[0 ].y );
286+ *c2 = *c2 * __half2float (s0[1 ].x );
287+ *c3 = *c3 * __half2float (s0[1 ].y );
288288
289- *c4 = __ocml_fmul_f32 ( *c4, __half2float (s1[0 ].x ) );
290- *c5 = __ocml_fmul_f32 ( *c5, __half2float (s1[0 ].y ) );
291- *c6 = __ocml_fmul_f32 ( *c6, __half2float (s1[1 ].x ) );
292- *c7 = __ocml_fmul_f32 ( *c7, __half2float (s1[1 ].y ));
289+ *c4 = *c4 * __half2float (s1[0 ].x );
290+ *c5 = *c5 * __half2float (s1[0 ].y );
291+ *c6 = *c6 * __half2float (s1[1 ].x );
292+ *c7 = *c7 * __half2float (s1[1 ].y );
293293 #else
294294 // NVIDIA implementation
295295 *c0 = __fmul_rn (*c0, __half2float (s0[0 ].x ));
0 commit comments