@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
41
41
/// Rewrite pattern anchored on `arith::ExtFOp`.
///
/// The pattern carries the target `Chipset`; `match()` hands it to
/// `isSupportedFp8` so that only the 8-bit float formats accepted for that
/// chipset are matched.
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  // Keep the base-class constructors available next to the chipset-aware one.
  using OpRewritePattern::OpRewritePattern;

  /// Builds the pattern in `ctx` and records the `chipset` it targets.
  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
      : OpRewritePattern<arith::ExtFOp>(ctx), chipset(chipset) {}

  /// Decides whether `op` is handled by this pattern.
  LogicalResult match(arith::ExtFOp op) const override;

  /// Replaces a previously matched `op` using `rewriter`.
  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;

  /// Chipset the conversion is being performed for.
  Chipset chipset;
};
@@ -68,6 +72,15 @@ struct TruncfToFloat16RewritePattern final
68
72
69
73
} // end namespace
70
74
75
+ static LogicalResult isSupportedFp8 (Type elementType, Chipset chipset) {
76
+ if (chipset.isGfx940 ())
77
+ return success (elementType.isFloat8E5M2FNUZ () ||
78
+ elementType.isFloat8E4M3FNUZ ());
79
+ if (chipset.hasOcpFp8 ())
80
+ return success (elementType.isFloat8E5M2 () || elementType.isFloat8E4M3FN ());
81
+ return failure ();
82
+ }
83
+
71
84
static Value castF32To (Type elementType, Value f32 , Location loc,
72
85
PatternRewriter &rewriter) {
73
86
if (elementType.isF32 ())
@@ -86,8 +99,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
86
99
return failure ();
87
100
inType = inVecType.getElementType ();
88
101
}
89
- return success (inType.isFloat8E5M2FNUZ () || inType.isFloat8E4M3FNUZ () ||
90
- inType.isFloat8E5M2 () || inType.isFloat8E4M3FN ());
102
+ return isSupportedFp8 (inType, chipset);
91
103
}
92
104
93
105
void ExtFOnFloat8RewritePattern::rewrite (arith::ExtFOp op,
@@ -218,10 +230,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
218
230
// Conversion between 8-bit floats is not supported with truncation enabled.
219
231
return failure ();
220
232
221
- return success ((((outType.isFloat8E5M2FNUZ () || outType.isFloat8E4M3FNUZ ()) &&
222
- chipset.isGfx940 ()) ||
223
- ((outType.isFloat8E5M2 () || outType.isFloat8E4M3FN ()) &&
224
- chipset.hasOcpFp8 ())));
233
+ return isSupportedFp8 (outType, chipset);
225
234
}
226
235
227
236
void TruncFToFloat8RewritePattern::rewrite (arith::TruncFOp op,
@@ -370,7 +379,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
370
379
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
371
380
372
381
if (convertFP8Arithmetic) {
373
- patterns.add <ExtFOnFloat8RewritePattern>(patterns.getContext ());
382
+ patterns.add <ExtFOnFloat8RewritePattern>(patterns.getContext (), chipset );
374
383
patterns.add <TruncFToFloat8RewritePattern>(patterns.getContext (),
375
384
saturateFP8Truncf, chipset);
376
385
}
@@ -389,7 +398,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
389
398
}
390
399
391
400
bool convertFP8Arithmetic =
392
- maybeChipset->majorVersion == 9 && * maybeChipset >= Chipset ( 9 , 4 , 0 );
401
+ maybeChipset->isGfx940 () || maybeChipset-> hasOcpFp8 ( );
393
402
arith::populateArithToAMDGPUConversionPatterns (
394
403
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
395
404
*maybeChipset);
0 commit comments