Skip to content

Commit ab7b9f1

Browse files
committed
[MLIR][AMDGPU] Implement emulated FP8 for the OCP formats. This part mostly just allows the new types.
1 parent 7ba4968 commit ab7b9f1

File tree

6 files changed

+49
-28
lines changed

6 files changed

+49
-28
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
4141

4242
def AMDGPU_ExtPackedFp8Op :
4343
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
44-
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
45-
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
44+
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
45+
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
4646
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
4747
Results<(outs F32:$res)> {
4848
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
6868
Arguments<(ins F32:$sourceA,
6969
Optional<F32>:$sourceB,
7070
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
71-
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
72-
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
71+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
72+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
7373
let summary = "Round two floats into a packed vector of 8-bit floats";
7474
let description = [{
7575
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
9595
Arguments<(ins F32:$source,
9696
I32:$stochiasticParam,
9797
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
98-
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
99-
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
98+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
99+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
100100
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
101101
let description = [{
102102
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
546546
VectorOfLengthAndType<[4], [F16]>,
547547
VectorOfLengthAndType<[2, 4], [BF16]>,
548548
VectorOfLengthAndType<[4, 8], [I8]>,
549-
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
549+
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
550550
def MFMAOutTypes : AnyTypeOf<[F64,
551551
VectorOfLengthAndType<[4, 16, 32], [F32]>,
552552
VectorOfLengthAndType<[4, 16, 32], [I32]>,

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ struct Chipset {
4747
DEFINE_COMP_OPERATOR(>)
4848
DEFINE_COMP_OPERATOR(>=)
4949
#undef DEFINE_COMP_OPERATOR
50+
51+
bool isGfx940() const {
52+
return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
53+
}
54+
bool hasOcpFp8() const {
55+
return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
56+
}
5057
};
5158

5259
} // namespace mlir::amdgpu

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -550,38 +550,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
550550
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
551551
}
552552

553-
if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
553+
if (destElem.isF32() &&
554+
((sourceElem.isFloat8E5M2FNUZ() && chipset >= kGfx940) ||
555+
(sourceElem.isFloat8E5M2() && chipset.hasOcpFp8()))) {
554556
// Known to be correct because there are no scalar f8 instructions and
555557
// because a length mismatch will have been caught by the verifier.
556558
Type sourceBElem =
557559
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
558560
if (m == 16 && n == 16 && k == 32 && b == 1) {
559-
if (sourceBElem.isFloat8E5M2FNUZ())
561+
if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
560562
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
561-
if (sourceBElem.isFloat8E4M3FNUZ())
563+
if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
562564
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
563565
}
564566
if (m == 32 && n == 32 && k == 16 && b == 1) {
565-
if (sourceBElem.isFloat8E5M2FNUZ())
567+
if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
566568
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
567-
if (sourceBElem.isFloat8E4M3FNUZ())
569+
if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
568570
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
569571
}
570572
}
571573

572-
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
574+
if (destElem.isF32() &&
575+
((sourceElem.isFloat8E4M3FNUZ() && chipset >= kGfx940) ||
576+
(sourceElem.isFloat8E4M3FN() && chipset.hasOcpFp8()))) {
573577
Type sourceBElem =
574578
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
575579
if (m == 16 && n == 16 && k == 32 && b == 1) {
576-
if (sourceBElem.isFloat8E5M2FNUZ())
580+
if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
577581
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
578-
if (sourceBElem.isFloat8E4M3FNUZ())
582+
if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
579583
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
580584
}
581585
if (m == 32 && n == 32 && k == 16 && b == 1) {
582-
if (sourceBElem.isFloat8E5M2FNUZ())
586+
if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
583587
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
584-
if (sourceBElem.isFloat8E4M3FNUZ())
588+
if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
585589
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
586590
}
587591
}
@@ -787,10 +791,11 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
787791
}
788792
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
789793
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
790-
if (sourceElemType.isFloat8E5M2FNUZ()) {
794+
if (sourceElemType.isFloat8E5M2FNUZ() || sourceElemType.isFloat8E5M2()) {
791795
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
792796
wordSel);
793-
} else if (sourceElemType.isFloat8E4M3FNUZ()) {
797+
} else if (sourceElemType.isFloat8E4M3FNUZ() ||
798+
sourceElemType.isFloat8E4M3FN()) {
794799
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
795800
wordSel);
796801
}
@@ -822,10 +827,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
822827
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
823828

824829
Value result;
825-
if (resultElemType.isFloat8E5M2FNUZ())
830+
if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
826831
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
827832
existing, wordSel);
828-
else if (resultElemType.isFloat8E4M3FNUZ())
833+
else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
829834
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
830835
existing, wordSel);
831836

@@ -857,10 +862,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
857862
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
858863

859864
Value result;
860-
if (resultElemType.isFloat8E5M2FNUZ())
865+
if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
861866
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
862867
existing, byteSel);
863-
else if (resultElemType.isFloat8E4M3FNUZ())
868+
else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
864869
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
865870
existing, byteSel);
866871

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
8686
return failure();
8787
inType = inVecType.getElementType();
8888
}
89-
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
89+
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
90+
inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
9091
}
9192

9293
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +217,11 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
216217
if (inType && inType.getWidth() <= 8 && saturateFP8)
217218
// Conversion between 8-bit floats is not supported with truncation enabled.
218219
return failure();
219-
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
220+
221+
return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
222+
chipset.isGfx940()) ||
223+
((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
224+
chipset.hasOcpFp8())));
220225
}
221226

222227
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,16 @@ LogicalResult MFMAOp::verify() {
272272
}
273273

274274
Type sourceBType = getSourceB().getType();
275-
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
275+
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ() ||
276+
sourceElem.isFloat8E5M2() || sourceElem.isFloat8E4M3FN()) {
276277
int64_t sourceBLen = 1;
277278
Type sourceBElem = sourceBType;
278279
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
279280
sourceBLen = sourceBVector.getNumElements();
280281
sourceBElem = sourceBVector.getElementType();
281282
}
282-
if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
283+
if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ() &&
284+
!sourceBElem.isFloat8E5M2() && !sourceBElem.isFloat8E4M3FN())
283285
return emitOpError("expected both source operands to have f8 elements");
284286
if (sourceLen != sourceBLen)
285287
return emitOpError(

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,9 @@ bool TosaValidation::isValidElementType(Type type) {
509509
if (isa<FloatType>(type)) {
510510
if (profile == TosaProfileEnum::BaseInference)
511511
return false;
512-
return type.isF32() || type.isF16() || type.isBF16();
512+
return type.isF32() || type.isF16() || type.isBF16() ||
513+
type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
514+
type.isFloat8E4M3FN() || type.isFloat8E5M2();
513515
}
514516
if (auto intTy = dyn_cast<IntegerType>(type)) {
515517
if (intTy.isUnsigned()) {

0 commit comments

Comments
 (0)