Skip to content

Commit 0272474

Browse files
committed
[MLIR][AMDGPU] Add OCP FP8 support to for new hardware
Upcoming hardware (gfx12 and some future gfx9) will support the OCP 8-bit float formats for their matrix multiplication intrinsics and conversion operations, retaining existing opcodes and compiler builtins. This commit adds support for these types to the MLIR wrappers around such operations, ensuring that the OCP types aren't used to generate those builtins on hardware that doesn't expect that format and, conversely, to ensure that the pre-OCP formats aren't used on new hardware.
1 parent d779685 commit 0272474

File tree

6 files changed

+49
-28
lines changed

6 files changed

+49
-28
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
4141

4242
def AMDGPU_ExtPackedFp8Op :
4343
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
44-
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
45-
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
44+
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
45+
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
4646
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
4747
Results<(outs F32:$res)> {
4848
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
6868
Arguments<(ins F32:$sourceA,
6969
Optional<F32>:$sourceB,
7070
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
71-
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
72-
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
71+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
72+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
7373
let summary = "Round two floats into a packed vector of 8-bit floats";
7474
let description = [{
7575
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
9595
Arguments<(ins F32:$source,
9696
I32:$stochiasticParam,
9797
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
98-
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
99-
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
98+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
99+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
100100
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
101101
let description = [{
102102
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
546546
VectorOfLengthAndType<[4], [F16]>,
547547
VectorOfLengthAndType<[2, 4], [BF16]>,
548548
VectorOfLengthAndType<[4, 8], [I8]>,
549-
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
549+
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
550550
def MFMAOutTypes : AnyTypeOf<[F64,
551551
VectorOfLengthAndType<[4, 16, 32], [F32]>,
552552
VectorOfLengthAndType<[4, 16, 32], [I32]>,

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ struct Chipset {
4747
DEFINE_COMP_OPERATOR(>)
4848
DEFINE_COMP_OPERATOR(>=)
4949
#undef DEFINE_COMP_OPERATOR
50+
51+
bool isGfx940() const {
52+
return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
53+
}
54+
bool hasOcpFp8() const {
55+
return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
56+
}
5057
};
5158

5259
} // namespace mlir::amdgpu

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -550,38 +550,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
550550
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
551551
}
552552

553-
if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
553+
if (destElem.isF32() &&
554+
((sourceElem.isFloat8E5M2FNUZ() && chipset >= kGfx940) ||
555+
(sourceElem.isFloat8E5M2() && chipset.hasOcpFp8()))) {
554556
// Known to be correct because there are no scalar f8 instructions and
555557
// because a length mismatch will have been caught by the verifier.
556558
Type sourceBElem =
557559
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
558560
if (m == 16 && n == 16 && k == 32 && b == 1) {
559-
if (sourceBElem.isFloat8E5M2FNUZ())
561+
if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
560562
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
561-
if (sourceBElem.isFloat8E4M3FNUZ())
563+
if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
562564
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
563565
}
564566
if (m == 32 && n == 32 && k == 16 && b == 1) {
565-
if (sourceBElem.isFloat8E5M2FNUZ())
567+
if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
566568
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
567-
if (sourceBElem.isFloat8E4M3FNUZ())
569+
if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
568570
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
569571
}
570572
}
571573

572-
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
574+
if (destElem.isF32() &&
575+
((sourceElem.isFloat8E4M3FNUZ() && chipset >= kGfx940) ||
576+
(sourceElem.isFloat8E4M3FN() && chipset.hasOcpFp8()))) {
573577
Type sourceBElem =
574578
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
575579
if (m == 16 && n == 16 && k == 32 && b == 1) {
576-
if (sourceBElem.isFloat8E5M2FNUZ())
580+
if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
577581
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
578-
if (sourceBElem.isFloat8E4M3FNUZ())
582+
if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
579583
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
580584
}
581585
if (m == 32 && n == 32 && k == 16 && b == 1) {
582-
if (sourceBElem.isFloat8E5M2FNUZ())
586+
if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
583587
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
584-
if (sourceBElem.isFloat8E4M3FNUZ())
588+
if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
585589
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
586590
}
587591
}
@@ -787,10 +791,11 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
787791
}
788792
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
789793
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
790-
if (sourceElemType.isFloat8E5M2FNUZ()) {
794+
if (sourceElemType.isFloat8E5M2FNUZ() || sourceElemType.isFloat8E5M2()) {
791795
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
792796
wordSel);
793-
} else if (sourceElemType.isFloat8E4M3FNUZ()) {
797+
} else if (sourceElemType.isFloat8E4M3FNUZ() ||
798+
sourceElemType.isFloat8E4M3FN()) {
794799
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
795800
wordSel);
796801
}
@@ -822,10 +827,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
822827
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
823828

824829
Value result;
825-
if (resultElemType.isFloat8E5M2FNUZ())
830+
if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
826831
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
827832
existing, wordSel);
828-
else if (resultElemType.isFloat8E4M3FNUZ())
833+
else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
829834
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
830835
existing, wordSel);
831836

@@ -857,10 +862,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
857862
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
858863

859864
Value result;
860-
if (resultElemType.isFloat8E5M2FNUZ())
865+
if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
861866
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
862867
existing, byteSel);
863-
else if (resultElemType.isFloat8E4M3FNUZ())
868+
else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
864869
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
865870
existing, byteSel);
866871

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
8686
return failure();
8787
inType = inVecType.getElementType();
8888
}
89-
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
89+
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
90+
inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
9091
}
9192

9293
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +217,11 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
216217
if (inType && inType.getWidth() <= 8 && saturateFP8)
217218
// Conversion between 8-bit floats is not supported with truncation enabled.
218219
return failure();
219-
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
220+
221+
return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
222+
chipset.isGfx940()) ||
223+
((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
224+
chipset.hasOcpFp8())));
220225
}
221226

222227
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,16 @@ LogicalResult MFMAOp::verify() {
272272
}
273273

274274
Type sourceBType = getSourceB().getType();
275-
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
275+
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ() ||
276+
sourceElem.isFloat8E5M2() || sourceElem.isFloat8E4M3FN()) {
276277
int64_t sourceBLen = 1;
277278
Type sourceBElem = sourceBType;
278279
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
279280
sourceBLen = sourceBVector.getNumElements();
280281
sourceBElem = sourceBVector.getElementType();
281282
}
282-
if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
283+
if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ() &&
284+
!sourceBElem.isFloat8E5M2() && !sourceBElem.isFloat8E4M3FN())
283285
return emitOpError("expected both source operands to have f8 elements");
284286
if (sourceLen != sourceBLen)
285287
return emitOpError(

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,9 @@ bool TosaValidation::isValidElementType(Type type) {
509509
if (isa<FloatType>(type)) {
510510
if (profile == TosaProfileEnum::BaseInference)
511511
return false;
512-
return type.isF32() || type.isF16() || type.isBF16();
512+
return type.isF32() || type.isF16() || type.isBF16() ||
513+
type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
514+
type.isFloat8E4M3FN() || type.isFloat8E5M2();
513515
}
514516
if (auto intTy = dyn_cast<IntegerType>(type)) {
515517
if (intTy.isUnsigned()) {

0 commit comments

Comments
 (0)