Skip to content

Commit 95d526f

Browse files
authored
[MLIR][Tosa] Fix argmax NaN propagate lowering (llvm#133074)
In propagate mode, NaNs compare equal to each other, so when there are several NaNs the index of the first one needs to be returned. This commit changes the index update condition to check that the current index is not that of a NaN. The commit also simplifies the argmax NaN-ignore lowering to use only OGT, which prevents any update in case of NaN. The only case where the index of a NaN is returned is when all values are NaN; this is covered by the fact that the initial index value is 0, so the absence of any update results in 0 being returned.
1 parent 931a78a commit 95d526f

File tree

2 files changed

+22
-31
lines changed

2 files changed

+22
-31
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

+16-24
Original file line numberDiff line numberDiff line change
@@ -2285,8 +2285,22 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
22852285

22862286
Value predicate;
22872287
if (isa<FloatType>(inElementTy)) {
2288-
predicate = rewriter.create<arith::CmpFOp>(
2289-
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2288+
if (argmaxOp.getNanMode() == "IGNORE") {
2289+
// Only update index & max value for non NaN values. If all
2290+
// values are NaNs, the initial index, which is 0, will be returned.
2291+
predicate = rewriter.create<arith::CmpFOp>(
2292+
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2293+
} else {
2294+
// Update max value if either of the following is true:
2295+
// - new value is bigger
2296+
// - cur max is not NaN and new value is NaN
2297+
Value gt = rewriter.create<arith::CmpFOp>(
2298+
nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
2299+
Value oldNonNaN = rewriter.create<arith::CmpFOp>(
2300+
nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
2301+
predicate = rewriter.create<arith::AndIOp>(
2302+
nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2303+
}
22902304
} else if (isa<IntegerType>(inElementTy)) {
22912305
predicate = rewriter.create<arith::CmpIOp>(
22922306
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
@@ -2299,28 +2313,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
22992313
nestedLoc, predicate, newValue, oldValue);
23002314
auto resultIndex = rewriter.create<arith::SelectOp>(
23012315
nestedLoc, predicate, newIndex, oldIndex);
2302-
2303-
// Check if we need to materialize compare and select for the given
2304-
// NaN propagation mode.
2305-
2306-
// "PROPAGATE" matches the default NaN propagation mode of the arith
2307-
// dialect so no compare and select is required.
2308-
//
2309-
// In the case "IGNORE" we check if the current argument is NaN and
2310-
// select the old index and value otherwise take the updated index and
2311-
// value.
2312-
if (const auto nanMode = argmaxOp.getNanMode();
2313-
isa<FloatType>(inElementTy) && nanMode == "IGNORE") {
2314-
// Unordered comparison of NaN against itself will always return
2315-
// true.
2316-
Value isNaN = rewriter.create<arith::CmpFOp>(
2317-
argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
2318-
newValue);
2319-
resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
2320-
oldValue, resultMax);
2321-
resultIndex = rewriter.create<arith::SelectOp>(
2322-
nestedLoc, isNaN, oldIndex, resultIndex);
2323-
}
23242316
nestedBuilder.create<linalg::YieldOp>(
23252317
nestedLoc, ValueRange({resultIndex, resultMax}));
23262318
});

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

+6-7
Original file line numberDiff line numberDiff line change
@@ -1525,7 +1525,9 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
15251525
// CHECK: arith.constant -3.40282347E+38 : f32
15261526
// CHECK: linalg.index
15271527
// CHECK: arith.index_cast
1528-
// CHECK: arith.cmpf ogt
1528+
// CHECK: arith.cmpf ugt
1529+
// CHECK: arith.cmpf ord
1530+
// CHECK: andi
15291531
// CHECK: select
15301532
// CHECK: select
15311533
// CHECK: linalg.yield
@@ -2230,12 +2232,12 @@ func.func @maximum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
22302232
// CHECK-LABEL: @argmax_nan_propagate
22312233
func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
22322234
// CHECK: linalg.generic
2233-
// CHECK: arith.cmpf ogt
2235+
// CHECK: arith.cmpf ugt
2236+
// CHECK: arith.cmpf ord
2237+
// CHECK: andi
22342238
// CHECK: arith.select
22352239
// CHECK: arith.select
22362240
// CHECK-NOT: arith.cmpf uno
2237-
// CHECK-NOT: arith.cmpf uno
2238-
// CHECK-NOT: arith.select
22392241
// CHECK-NOT: arith.select
22402242
// CHECK: linalg.yield
22412243
%11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<4xi32>
@@ -2267,9 +2269,6 @@ func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
22672269
// CHECK: arith.cmpf ogt
22682270
// CHECK: arith.select
22692271
// CHECK: arith.select
2270-
// CHECK: arith.cmpf uno
2271-
// CHECK: arith.select
2272-
// CHECK: arith.select
22732272
// CHECK: linalg.yield
22742273
%12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<4xi32>
22752274
return

0 commit comments

Comments
 (0)