Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :

def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
Results<(outs F32:$res)> {
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
Arguments<(ins F32:$sourceA,
Optional<F32>:$sourceB,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round two floats into a packed vector of 8-bit floats";
let description = [{
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
Arguments<(ins F32:$source,
I32:$stochiasticParam,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
let description = [{
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
VectorOfLengthAndType<[4], [F16]>,
VectorOfLengthAndType<[2, 4], [BF16]>,
VectorOfLengthAndType<[4, 8], [I8]>,
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
Original file line number Diff line number Diff line change
@@ -49,6 +49,11 @@ struct Chipset {
#undef DEFINE_COMP_OPERATOR
};

/// Return true if `chipset` has major version 9 with a minor version of at
/// least 5, or has a major version of at least 12.
/// NOTE(review): presumably "OCP" refers to the OCP 8-bit float formats
/// (f8E5M2 / f8E4M3FN) as opposed to the FNUZ variants — confirm.
inline bool hasOcpFp8(const Chipset &chipset) {
  const bool isGfx9WithOcp =
      chipset.majorVersion == 9 && chipset.minorVersion >= 5;
  const bool isGfx12OrLater = chipset.majorVersion >= 12;
  return isGfx9WithOcp || isGfx12OrLater;
}

} // namespace mlir::amdgpu

#endif
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
@@ -132,6 +132,9 @@ class Type {
bool isF64() const;
bool isF80() const;
bool isF128() const;
/// Return true if this is a float type (with the specified width).
bool isFloat() const;
bool isFloat(unsigned width) const;

/// Return true if this is an integer type (with the specified width).
bool isInteger() const;
54 changes: 33 additions & 21 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
@@ -474,6 +474,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
}
}

/// Return true if `type` is the 5-bit-exponent (E5M2) flavour of 8-bit float
/// that the `_bf8` instructions consume on `chipset`: `f8E5M2FNUZ` when the
/// chipset is exactly gfx942, `f8E5M2` when `hasOcpFp8(chipset)` holds.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  if (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type))
    return true;
  return hasOcpFp8(chipset) && isa<Float8E5M2Type>(type);
}

/// Return true if `type` is the 4-bit-exponent (E4M3) flavour of 8-bit float
/// that the `_fp8` instructions consume on `chipset`: `f8E4M3FNUZ` when the
/// chipset is exactly gfx942, `f8E4M3FN` when `hasOcpFp8(chipset)` holds.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  if (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type))
    return true;
  return hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type);
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -570,40 +584,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}

if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
chipset >= kGfx942) {
if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (isa<Float8E5M2FNUZType>(sourceBElem))
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
if (isa<Float8E4M3FNUZType>(sourceBElem))
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (isa<Float8E5M2FNUZType>(sourceBElem))
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
if (isa<Float8E4M3FNUZType>(sourceBElem))
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}

if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
chipset >= kGfx942) {
if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (isa<Float8E5M2FNUZType>(sourceBElem))
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
if (isa<Float8E4M3FNUZType>(sourceBElem))
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (isa<Float8E5M2FNUZType>(sourceBElem))
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
if (isa<Float8E4M3FNUZType>(sourceBElem))
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
@@ -781,7 +793,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset < kGfx942)
if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -811,10 +823,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
if (isa<Float8E5M2FNUZType>(sourceElemType)) {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
} else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -825,7 +837,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset < kGfx942)
if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -846,10 +858,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

Value result;
if (isa<Float8E5M2FNUZType>(resultElemType))
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
else if (isa<Float8E4M3FNUZType>(resultElemType))
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);

@@ -862,7 +874,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset < kGfx942)
if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -881,10 +893,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

Value result;
if (isa<Float8E5M2FNUZType>(resultElemType))
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
else if (isa<Float8E4M3FNUZType>(resultElemType))
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);

24 changes: 20 additions & 4 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
@@ -30,6 +30,9 @@ using namespace mlir;
using namespace mlir::amdgpu;

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);

struct ArithToAMDGPUConversionPass final
: impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
using impl::ArithToAMDGPUConversionPassBase<
@@ -41,6 +44,10 @@ struct ArithToAMDGPUConversionPass final
/// Rewrite pattern over `arith.extf` ops. It carries the target chipset so
/// that `match` can decide which 8-bit float source types are accepted.
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;

// Target chipset this pattern was created for.
Chipset chipset;
// Build the pattern in context `ctx` and record `chipset` for later matching.
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
: OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}

// `match` decides applicability; `rewrite` performs the replacement.
LogicalResult match(arith::ExtFOp op) const override;
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};
@@ -68,6 +75,14 @@ struct TruncfToFloat16RewritePattern final

} // end namespace

/// Report whether `elementType` is an 8-bit float type accepted for `chipset`.
/// On gfx942 only the FNUZ types (`f8E4M3FNUZ`, `f8E5M2FNUZ`) succeed; on
/// chipsets where `hasOcpFp8` holds only `f8E4M3FN` and `f8E5M2` succeed; any
/// other chipset yields failure.
static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
  bool supported = false;
  if (chipset == kGfx942)
    supported = isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
  else if (hasOcpFp8(chipset))
    supported = isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
  return success(supported);
}

static Value castF32To(Type elementType, Value f32, Location loc,
PatternRewriter &rewriter) {
if (elementType.isF32())
@@ -86,7 +101,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
return isSupportedF8(inType, chipset);
}

void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -219,7 +234,8 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));

return isSupportedF8(outType, chipset);
}

void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -365,7 +381,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {

if (convertFP8Arithmetic) {
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
saturateFP8Truncf, chipset);
}
@@ -384,7 +400,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}

bool convertFP8Arithmetic =
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 2);
*maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
}

Type sourceBType = getSourceB().getType();
if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
if (sourceElem.isFloat(8)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
if (!sourceBElem.isFloat(8))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
@@ -696,7 +696,8 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {

bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
return type.isF32() || type.isF16() || type.isBF16();
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
Float8E5M2Type>(type);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are missing two f8 types

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the TOSA spec doesn't support them #106160 (comment)
This was addressed in 5bace46

} else if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
9 changes: 9 additions & 0 deletions mlir/lib/IR/Types.cpp
Original file line number Diff line number Diff line change
@@ -42,6 +42,15 @@ bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }

/// Return true if this type is a `FloatType` of any width.
bool Type::isFloat() const {
  return llvm::isa<FloatType>(*this);
}

/// Return true if this is a float type with the specified width.
bool Type::isFloat(unsigned width) const {
if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
return fltTy.getWidth() == width;
return false;
}

bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }

bool Type::isInteger() const { return llvm::isa<IntegerType>(*this); }
109 changes: 109 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 | FileCheck %s

// CHECK-LABEL: func @ext_scalar
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2 to i8
// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
// CHECK: return [[EXT]]
func.func @ext_scalar(%v: f8E5M2) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
func.return %ret : f32
}

// CHECK-LABEL: func @ext_short_vec
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
// CHECK: return [[EXT]]
func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
func.return %ret : f32
}

// CHECK-LABEL: func @ext_full_vec(
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
// CHECK: return [[EXT]] : f32

func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
func.return %ret : f32
}

// CHECK-LABEL: func @packed_trunc
// CHECK-SAME: ([[V:%.+]]: f32)
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
%ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FN>
func.return %ret : vector<4xf8E4M3FN>
}

// CHECK-LABEL: func @packed_truncx2
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
%ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FN>
func.return %ret : vector<4xf8E4M3FN>
}

// CHECK-LABEL: func @packed_truncx2_into
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
%ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
func.return %ret : vector<4xf8E5M2>
}

// CHECK-LABEL: func @packed_stoch_round
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
%ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FN>
func.return %ret : vector<4xf8E4M3FN>
}

// CHECK-LABEL: func @packed_stoch_round_into
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
%ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
func.return %ret : vector<4xf8E5M2>
}
58 changes: 58 additions & 0 deletions mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// RUN: mlir-opt --split-input-file %s \
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx950 saturate-fp8-truncf=true}))' \
// RUN: | FileCheck %s

// RUN: mlir-opt --split-input-file %s \
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx1200 saturate-fp8-truncf=true}))' \
// RUN: | FileCheck %s

// CHECK-LABEL: func.func @scalar_trunc
// CHECK-SAME: ([[V:%.+]]: f16)
// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2>
// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2 from vector<4xf8E5M2>
// CHECK: return [[W]] : f8E5M2
func.func @scalar_trunc(%v: f16) -> f8E5M2 {
%w = arith.truncf %v : f16 to f8E5M2
return %w : f8E5M2
}

// No 0-D test because arith.truncf hasn't been extended to support it.

// -----

// Saturating truncation of a 2-element f32 vector to f8E4M3FN: finite inputs
// are clamped to [-448, 448] before packing, while infinities and NaN are
// passed through unclamped. The label below names the function exactly
// rather than relying on a prefix match.
// CHECK-LABEL: func.func @vector_trunc_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FN> {
// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-4.480000e+02> : vector<2xf32>
// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<4.480000e+02> : vector<2xf32>
// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32>
// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32>
// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
// CHECK: [[F0:%.+]] = vector.extract [[SATURATED]][0]
// CHECK: [[F1:%.+]] = vector.extract [[SATURATED]][1]
// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FN>
// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FN> to vector<2xf8E4M3FN>
// CHECK: return [[W]] : vector<2xf8E4M3FN>
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FN> {
%w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FN>
return %w : vector<2xf8E4M3FN>
}
176 changes: 176 additions & 0 deletions mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1200" | FileCheck %s

// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2)
// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32
// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
// CHECK: return [[W]]
func.func @scalar_ext(%v: f8E5M2) -> f16 {
%w = arith.extf %v : f8E5M2 to f16
return %w : f16
}

// No 0-D test because arith.extf hasn't been extended to support it.

// -----

// CHECK-LABEL: func.func @vector_ext_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2>)
// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to f32
// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2> to f32
// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
// CHECK: return [[W1]] : vector<2xf64>

func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
%w = arith.extf %v : vector<2xf8E5M2> to vector<2xf64>
return %w : vector<2xf64>
}

// -----

// CHECK-LABEL: func.func @vector_ext_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FN>)
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
// CHECK: [[W0:%.+]] = vector.insert [[F0]]
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]

// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]

// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
// CHECK: return [[W8]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> {
%w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32>
return %w : vector<9xf32>
}

// -----

// CHECK-LABEL: func.func @scalar_trunc
// CHECK-SAME: ([[V:%.+]]: f16)
// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32
// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2>
// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2 from vector<4xf8E5M2>
// CHECK: return [[W]] : f8E5M2
func.func @scalar_trunc(%v: f16) -> f8E5M2 {
%w = arith.truncf %v : f16 to f8E5M2
return %w : f8E5M2
}

// No 0-D test because arith.truncf hasn't been extended to support it.

// -----

// CHECK-LABEL: func.func @vector_trunc_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2> {
// CHECK: [[V0:%.+]] = vector.extract [[V]][0]
// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32
// CHECK: [[V1:%.+]] = vector.extract [[V]][1]
// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32
// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2>
// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
// CHECK: return [[W]] : vector<2xf8E5M2>
func.func @vector_trunc_short(%v: vector<2xf64>) -> vector<2xf8E5M2> {
%w = arith.truncf %v : vector<2xf64> to vector<2xf8E5M2>
return %w : vector<2xf8E5M2>
}

// -----

// CHECK-LABEL: func.func @vector_trunc_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf32>)
// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FN>
// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}

// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}

// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
// CHECK: return [[W]]
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FN> {
%w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FN>
return %w : vector<9xf8E4M3FN>
}

// -----

// CHECK-LABEL: func.func @vector_trunc_long_2d
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>)
// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FN>
// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}

// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}

// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FN> to vector<1x9xf8E4M3FN>
// CHECK: return [[RE]]
func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> {
%w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FN>
return %w : vector<1x9xf8E4M3FN>
}

// -----

// CHECK-LABEL: func.func @vector_ext_long_2d
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>)
// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN>
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
// CHECK: [[W0:%.+]] = vector.insert [[F0]]
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]

// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]

// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
// CHECK: return [[CAST]]
func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> {
%w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32>
return %w : vector<1x9xf32>
}