Skip to content

Commit b00ddce

Browse files
authored
[mlir][affine] Fix a crash when cast incompatible type (#145162)
This PR fixes a crash in `getSemiAffineExprFromFlatForm` when localExpr is not `AffineBinaryOpExpr`. Fixes #144091.
1 parent f9fce49 commit b00ddce

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,11 +1174,15 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
11741174
// the indices in `coefficients` map, and affine expression corresponding to
11751175
// in indices in `indexToExprMap` map.
11761176
for (const auto &it : llvm::enumerate(localExprs)) {
1177-
AffineExpr expr = it.value();
11781177
if (flatExprs[numDims + numSymbols + it.index()] == 0)
11791178
continue;
1180-
AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
1181-
AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
1179+
AffineExpr expr = it.value();
1180+
auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1181+
if (!binaryExpr)
1182+
continue;
1183+
1184+
AffineExpr lhs = binaryExpr.getLHS();
1185+
AffineExpr rhs = binaryExpr.getRHS();
11821186
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
11831187
(isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
11841188
isa<AffineConstantExpr>(rhs)))) {

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,19 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
592592
//CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
593593
return %a : index
594594
}
595+
596+
// -----
597+
598+
// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
599+
func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
600+
%c0 = arith.constant 0 : index
601+
%c1 = arith.constant 1 : index
602+
%c13 = arith.constant 13 : index
603+
%dim = tensor.dim %arg0, %c0 : tensor<?xf32>
604+
%a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
605+
%b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
606+
// CHECK: %[[C6:.*]] = arith.constant 6 : index
607+
// CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
608+
// CHECK-NEXT: return %[[C6]], %[[C7]]
609+
return %a, %b : index, index
610+
}

0 commit comments

Comments
 (0)