@@ -550,38 +550,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
550
550
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
551
551
}
552
552
553
- if (sourceElem.isFloat8E5M2FNUZ () && destElem.isF32 () && chipset >= kGfx940 ) {
553
+ if (destElem.isF32 () &&
554
+ ((sourceElem.isFloat8E5M2FNUZ () && chipset >= kGfx940 ) ||
555
+ (sourceElem.isFloat8E5M2 () && chipset.hasOcpFp8 ()))) {
554
556
// Known to be correct because there are no scalar f8 instructions and
555
557
// because a length mismatch will have been caught by the verifier.
556
558
Type sourceBElem =
557
559
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
558
560
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
559
- if (sourceBElem.isFloat8E5M2FNUZ ())
561
+ if (sourceBElem.isFloat8E5M2FNUZ () || sourceBElem. isFloat8E5M2 () )
560
562
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
561
- if (sourceBElem.isFloat8E4M3FNUZ ())
563
+ if (sourceBElem.isFloat8E4M3FNUZ () || sourceBElem. isFloat8E4M3FN () )
562
564
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
563
565
}
564
566
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
565
- if (sourceBElem.isFloat8E5M2FNUZ ())
567
+ if (sourceBElem.isFloat8E5M2FNUZ () || sourceBElem. isFloat8E5M2 () )
566
568
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
567
- if (sourceBElem.isFloat8E4M3FNUZ ())
569
+ if (sourceBElem.isFloat8E4M3FNUZ () || sourceBElem. isFloat8E4M3FN () )
568
570
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
569
571
}
570
572
}
571
573
572
- if (sourceElem.isFloat8E4M3FNUZ () && destElem.isF32 () && chipset >= kGfx940 ) {
574
+ if (destElem.isF32 () &&
575
+ ((sourceElem.isFloat8E4M3FNUZ () && chipset >= kGfx940 ) ||
576
+ (sourceElem.isFloat8E4M3FN () && chipset.hasOcpFp8 ()))) {
573
577
Type sourceBElem =
574
578
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
575
579
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
576
- if (sourceBElem.isFloat8E5M2FNUZ ())
580
+ if (sourceBElem.isFloat8E5M2FNUZ () || sourceBElem. isFloat8E5M2 () )
577
581
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
578
- if (sourceBElem.isFloat8E4M3FNUZ ())
582
+ if (sourceBElem.isFloat8E4M3FNUZ () || sourceBElem. isFloat8E4M3FN () )
579
583
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
580
584
}
581
585
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
582
- if (sourceBElem.isFloat8E5M2FNUZ ())
586
+ if (sourceBElem.isFloat8E5M2FNUZ () || sourceBElem. isFloat8E5M2 () )
583
587
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
584
- if (sourceBElem.isFloat8E4M3FNUZ ())
588
+ if (sourceBElem.isFloat8E4M3FNUZ () || sourceBElem. isFloat8E4M3FN () )
585
589
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
586
590
}
587
591
}
@@ -787,10 +791,11 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
787
791
}
788
792
Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
789
793
Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
790
- if (sourceElemType.isFloat8E5M2FNUZ ()) {
794
+ if (sourceElemType.isFloat8E5M2FNUZ () || sourceElemType. isFloat8E5M2 () ) {
791
795
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
792
796
wordSel);
793
- } else if (sourceElemType.isFloat8E4M3FNUZ ()) {
797
+ } else if (sourceElemType.isFloat8E4M3FNUZ () ||
798
+ sourceElemType.isFloat8E4M3FN ()) {
794
799
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
795
800
wordSel);
796
801
}
@@ -822,10 +827,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
822
827
Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
823
828
824
829
Value result;
825
- if (resultElemType.isFloat8E5M2FNUZ ())
830
+ if (resultElemType.isFloat8E5M2FNUZ () || resultElemType. isFloat8E5M2 () )
826
831
result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
827
832
existing, wordSel);
828
- else if (resultElemType.isFloat8E4M3FNUZ ())
833
+ else if (resultElemType.isFloat8E4M3FNUZ () || resultElemType. isFloat8E4M3FN () )
829
834
result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
830
835
existing, wordSel);
831
836
@@ -857,10 +862,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
857
862
Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
858
863
859
864
Value result;
860
- if (resultElemType.isFloat8E5M2FNUZ ())
865
+ if (resultElemType.isFloat8E5M2FNUZ () || resultElemType. isFloat8E5M2 () )
861
866
result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
862
867
existing, byteSel);
863
- else if (resultElemType.isFloat8E4M3FNUZ ())
868
+ else if (resultElemType.isFloat8E4M3FNUZ () || resultElemType. isFloat8E4M3FN () )
864
869
result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
865
870
existing, byteSel);
866
871
0 commit comments