Skip to content

[MLIR][AMDGPU] Add OCP FP8 support for new hardware #106160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :

def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
Results<(outs F32:$res)> {
let summary = "Extend one of a vector of packed fp8 values to a float";
Expand All @@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
Arguments<(ins F32:$sourceA,
Optional<F32>:$sourceB,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round two floats into a packed vector of 8-bit floats";
let description = [{
Round the inputs `sourceA` and `sourceB` (which is undefined if not
Expand All @@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
Arguments<(ins F32:$source,
I32:$stochiasticParam,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
let description = [{
Round the input `source`, adding in `stochiasticParam`, and place it into
Expand Down Expand Up @@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
VectorOfLengthAndType<[4], [F16]>,
VectorOfLengthAndType<[2, 4], [BF16]>,
VectorOfLengthAndType<[4, 8], [I8]>,
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ struct Chipset {
#undef DEFINE_COMP_OPERATOR
};

inline bool isGfx940Series(const Chipset &chipset) {
return chipset.majorVersion == 9 && chipset.minorVersion == 4;
}
inline bool hasOcpFp8(const Chipset &chipset) {
return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
chipset.majorVersion >= 12;
}

} // namespace mlir::amdgpu

#endif
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ class Type {
bool isF64() const;
bool isF80() const;
bool isF128() const;
/// Return true if this is an float type (with the specified width).
bool isFloat() const;
bool isFloat(unsigned width) const;
Comment on lines +143 to +145
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!


/// Return true if this is an integer type (with the specified width).
bool isInteger() const;
Expand Down
52 changes: 33 additions & 19 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
}
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
return (isGfx940Series(chipset) && type.isFloat8E5M2FNUZ()) ||
(hasOcpFp8(chipset) && type.isFloat8E5M2());
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
return (isGfx940Series(chipset) && type.isFloat8E4M3FNUZ()) ||
(hasOcpFp8(chipset) && type.isFloat8E4M3FN());
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
Expand Down Expand Up @@ -550,38 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}

if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
if (sourceBElem.isFloat8E4M3FNUZ())
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
if (sourceBElem.isFloat8E4M3FNUZ())
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}

if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
if (sourceBElem.isFloat8E4M3FNUZ())
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
if (sourceBElem.isFloat8E4M3FNUZ())
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
Expand Down Expand Up @@ -757,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset < kGfx940)
if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Expand Down Expand Up @@ -787,10 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
if (sourceElemType.isFloat8E5M2FNUZ()) {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
} else if (sourceElemType.isFloat8E4M3FNUZ()) {
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
Expand All @@ -801,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset < kGfx940)
if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Expand All @@ -822,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

Value result;
if (resultElemType.isFloat8E5M2FNUZ())
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
else if (resultElemType.isFloat8E4M3FNUZ())
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);

Expand All @@ -838,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset < kGfx940)
if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Expand All @@ -857,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

Value result;
if (resultElemType.isFloat8E5M2FNUZ())
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
else if (resultElemType.isFloat8E4M3FNUZ())
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);

Expand Down
21 changes: 17 additions & 4 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;

Chipset chipset;
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
: OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}

LogicalResult match(arith::ExtFOp op) const override;
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};
Expand Down Expand Up @@ -68,6 +72,14 @@ struct TruncfToFloat16RewritePattern final

} // end namespace

static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
if (isGfx940Series(chipset))
return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
if (hasOcpFp8(chipset))
return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
return failure();
}

static Value castF32To(Type elementType, Value f32, Location loc,
PatternRewriter &rewriter) {
if (elementType.isF32())
Expand All @@ -86,7 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
return isSupportedF8(inType, chipset);
}

void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Expand Down Expand Up @@ -216,7 +228,8 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());

return isSupportedF8(outType, chipset);
}

void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
Expand Down Expand Up @@ -365,7 +378,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {

if (convertFP8Arithmetic) {
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
saturateFP8Truncf, chipset);
}
Expand All @@ -384,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}

bool convertFP8Arithmetic =
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
isGfx940Series(*maybeChipset) || hasOcpFp8(*maybeChipset);
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
}

Type sourceBType = getSourceB().getType();
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
if (sourceElem.isFloat(8)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
if (!sourceBElem.isFloat(8))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,8 @@ bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
if (profile == TosaProfileEnum::BaseInference)
return false;
return type.isF32() || type.isF16() || type.isBF16();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TOSA spec only supports the OCP versions of the FP8 formats, and this pass is verifying adherence to the spec, so we should only add the Float8E4M3FNType and Float8E5M2Type in this location.

return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNUZType,
Float8E5M2FNUZType, Float8E4M3FNType, Float8E5M2Type>(type);
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/IR/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }

bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }

/// Return true if this is an integer type with the specified width.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo integer should be float?

bool Type::isFloat(unsigned width) const {
if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
return fltTy.getWidth() == width;
return false;
}

bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }

bool Type::isInteger() const { return llvm::isa<IntegerType>(*this); }
Expand Down
Loading
Loading