Skip to content

Commit 1b790e3

Browse files
committed
[MLIR][AMDGPU] Clean up and redo after other recent patches here.
1 parent 1848df4 commit 1b790e3

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ struct Chipset {
4949
#undef DEFINE_COMP_OPERATOR
5050

5151
bool isGfx940() const {
52-
return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
52+
return majorVersion == 9 && minorVersion >= 4 && minorVersion < 5;
5353
}
5454
bool hasOcpFp8() const {
55-
return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
55+
return (majorVersion == 9 && minorVersion >= 5) || majorVersion >= 12;
5656
}
5757
};
5858

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
771771
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
772772
ConversionPatternRewriter &rewriter) const {
773773
Location loc = op.getLoc();
774-
if (chipset.majorVersion != 9 || chipset < kGfx940)
774+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
775775
return rewriter.notifyMatchFailure(
776776
loc, "Fp8 conversion instructions are not available on target "
777777
"architecture and their emulation is not implemented");
@@ -815,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
815815
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
816816
ConversionPatternRewriter &rewriter) const {
817817
Location loc = op.getLoc();
818-
if (chipset.majorVersion != 9 || chipset < kGfx940)
818+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
819819
return rewriter.notifyMatchFailure(
820820
loc, "Fp8 conversion instructions are not available on target "
821821
"architecture and their emulation is not implemented");
@@ -852,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
852852
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
853853
ConversionPatternRewriter &rewriter) const {
854854
Location loc = op.getLoc();
855-
if (chipset.majorVersion != 9 || chipset < kGfx940)
855+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
856856
return rewriter.notifyMatchFailure(
857857
loc, "Fp8 conversion instructions are not available on target "
858858
"architecture and their emulation is not implemented");

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
4141
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
4242
using OpRewritePattern::OpRewritePattern;
4343

44+
Chipset chipset;
45+
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
46+
: OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
47+
4448
LogicalResult match(arith::ExtFOp op) const override;
4549
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
4650
};
@@ -68,6 +72,15 @@ struct TruncfToFloat16RewritePattern final
6872

6973
} // end namespace
7074

75+
static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
76+
if (chipset.isGfx940())
77+
return success(elementType.isFloat8E5M2FNUZ() ||
78+
elementType.isFloat8E4M3FNUZ());
79+
if (chipset.hasOcpFp8())
80+
return success(elementType.isFloat8E5M2() || elementType.isFloat8E4M3FN());
81+
return failure();
82+
}
83+
7184
static Value castF32To(Type elementType, Value f32, Location loc,
7285
PatternRewriter &rewriter) {
7386
if (elementType.isF32())
@@ -86,8 +99,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
8699
return failure();
87100
inType = inVecType.getElementType();
88101
}
89-
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
90-
inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
102+
return isSupportedFp8(inType, chipset);
91103
}
92104

93105
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -218,10 +230,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
218230
// Conversion between 8-bit floats is not supported with truncation enabled.
219231
return failure();
220232

221-
return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
222-
chipset.isGfx940()) ||
223-
((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
224-
chipset.hasOcpFp8())));
233+
return isSupportedFp8(outType, chipset);
225234
}
226235

227236
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -370,7 +379,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
370379
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
371380

372381
if (convertFP8Arithmetic) {
373-
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
382+
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
374383
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
375384
saturateFP8Truncf, chipset);
376385
}
@@ -389,7 +398,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
389398
}
390399

391400
bool convertFP8Arithmetic =
392-
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
401+
maybeChipset->isGfx940() || maybeChipset->hasOcpFp8();
393402
arith::populateArithToAMDGPUConversionPatterns(
394403
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
395404
*maybeChipset);

0 commit comments

Comments
 (0)