Skip to content

Commit 856d708

Browse files
[mlir][transform] Plumb a simplified form of AffineMin folding into transform.pad-tiling-interface
This revision introduces a simple variant of AffineMin folding in makeComposedFoldedAffineApply and makes use of it in transform.pad-tiling-interface. Since this version explicitly call ValueBoundsInterface, it may be too expensive and is only activate behind a flag. It results in better foldings when mixing tiling and padding, including with dynamic shapes. This should be further composed with #145068 to provide full simplification and address the remaining TODO in the test.
1 parent d6a486c commit 856d708

File tree

6 files changed

+251
-41
lines changed

6 files changed

+251
-41
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,11 @@ void canonicalizeSetAndOperands(IntegerSet *set,
410410
/// other AffineApplyOps supplying those operands. The operands of the resulting
411411
/// AffineApplyOp do not change the length of AffineApplyOp chains.
412412
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
413-
ArrayRef<OpFoldResult> operands);
413+
ArrayRef<OpFoldResult> operands,
414+
bool composeAffineMin = false);
414415
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
415-
ArrayRef<OpFoldResult> operands);
416+
ArrayRef<OpFoldResult> operands,
417+
bool composeAffineMin = false);
416418

417419
/// Constructs an AffineApplyOp that applies `map` to `operands` after composing
418420
/// the map with the maps of any other AffineApplyOp supplying the operands,
@@ -421,16 +423,19 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
421423
/// map.
422424
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
423425
AffineMap map,
424-
ArrayRef<OpFoldResult> operands);
426+
ArrayRef<OpFoldResult> operands,
427+
bool composeAffineMin = false);
425428
/// Variant of `makeComposedFoldedAffineApply` that applies to an expression.
426429
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
427430
AffineExpr expr,
428-
ArrayRef<OpFoldResult> operands);
431+
ArrayRef<OpFoldResult> operands,
432+
bool composeAffineMin = false);
429433
/// Variant of `makeComposedFoldedAffineApply` suitable for multi-result maps.
430434
/// Note that this may create as many affine.apply operations as the map has
431435
/// results given that affine.apply must be single-result.
432436
SmallVector<OpFoldResult> makeComposedFoldedMultiResultAffineApply(
433-
OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands);
437+
OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
438+
bool composeAffineMin = false);
434439

435440
/// Returns an AffineMinOp obtained by composing `map` and `operands` with
436441
/// AffineApplyOps supplying those operands.
@@ -459,7 +464,8 @@ OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
459464
/// terminal symbol, i.e., a symbol defined at the top level or a block/function
460465
/// argument.
461466
void fullyComposeAffineMapAndOperands(AffineMap *map,
462-
SmallVectorImpl<Value> *operands);
467+
SmallVectorImpl<Value> *operands,
468+
bool composeAffineMin = false);
463469

464470
} // namespace affine
465471
} // namespace mlir

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class ValueBoundsConstraintSet
135135

136136
/// Construct a variable for a map and its operands.
137137
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
138-
Variable(AffineMap map, ArrayRef<Value> mapOperands);
138+
Variable(AffineMap map, ValueRange mapOperands);
139139

140140
MLIRContext *getContext() const { return map.getContext(); }
141141

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 103 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1212
#include "mlir/Dialect/UB/IR/UBOps.h"
1313
#include "mlir/Dialect/Utils/StaticValueUtils.h"
14+
#include "mlir/IR/AffineExpr.h"
1415
#include "mlir/IR/AffineExprVisitor.h"
1516
#include "mlir/IR/IRMapping.h"
1617
#include "mlir/IR/IntegerSet.h"
1718
#include "mlir/IR/Matchers.h"
1819
#include "mlir/IR/OpDefinition.h"
1920
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/IR/Value.h"
2022
#include "mlir/Interfaces/ShapedOpInterfaces.h"
2123
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2224
#include "mlir/Transforms/InliningUtils.h"
@@ -26,7 +28,9 @@
2628
#include "llvm/ADT/SmallVectorExtras.h"
2729
#include "llvm/ADT/TypeSwitch.h"
2830
#include "llvm/Support/Debug.h"
31+
#include "llvm/Support/LogicalResult.h"
2932
#include "llvm/Support/MathExtras.h"
33+
#include <limits>
3034
#include <numeric>
3135
#include <optional>
3236

@@ -1042,6 +1046,59 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
10421046
map.getContext());
10431047
}
10441048

1049+
/// Assuming `dimOrSym` is a quantity in `map` that is defined by `minOp`,
1050+
/// replaces the patterns:
1051+
/// ```
1052+
/// dimOrSym.ceildiv(cst) * cst
1053+
/// (dimOrSym + cst - 1).floordiv(cst) * cst
1054+
/// ```
1055+
/// by `cst` in `map`.
1056+
/// This simplification is valid iff `minOp` is guaranteed to be nonnegative.
1057+
/// Additionally, allows the caller to pass `affineMinKnownToBeNonNegative` to
1058+
/// inject static information that may not be statically discoverable.
1059+
/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
1060+
/// for the nonnegative case, if `affineMinKnownToBeNonNegative` is false.
1061+
static LogicalResult replaceAffineMinBoundingBoxExpression(
1062+
AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map,
1063+
bool affineMinKnownToBeNonNegative = false) {
1064+
auto affineMinMap = minOp.getAffineMap();
1065+
if (!affineMinKnownToBeNonNegative) {
1066+
ValueRange values = minOp->getOperands();
1067+
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1068+
AffineMap row = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
1069+
FailureOr<int64_t> lowerBound =
1070+
ValueBoundsConstraintSet::computeConstantBound(
1071+
presburger::BoundType::LB, {row, values},
1072+
/*stopCondition=*/nullptr,
1073+
/*closedUB=*/true);
1074+
if (failed(lowerBound) || lowerBound.value() < 0)
1075+
return failure();
1076+
}
1077+
}
1078+
1079+
AffineMap initialMap = *map;
1080+
for (unsigned i = 0, e = affineMinMap.getNumResults(); i != e; ++i) {
1081+
auto m = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
1082+
// TODO: this should also work with nonnegative symbolic divisors.
1083+
if (!m.isSingleConstant())
1084+
continue;
1085+
1086+
auto cst = m.getSingleConstantResult();
1087+
DenseMap<AffineExpr, AffineExpr> repl;
1088+
// dimOrSym.ceilDiv(cst) * cst -> cst
1089+
repl[dimOrSym.ceilDiv(cst) * cst] =
1090+
getAffineConstantExpr(cst, minOp.getContext());
1091+
// (dimOrSym + cst - 1).floorDiv(cst) * cst -> cst
1092+
repl[(dimOrSym + cst - 1).floorDiv(cst) * cst] =
1093+
getAffineConstantExpr(cst, minOp.getContext());
1094+
auto newMap = map->replace(repl);
1095+
if (newMap == *map)
1096+
continue;
1097+
*map = newMap;
1098+
}
1099+
return success(*map != initialMap);
1100+
}
1101+
10451102
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
10461103
/// defining AffineApplyOp expression and operands.
10471104
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1052,10 +1109,13 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
10521109
/// 2. `map` dim and symbols are gradually shifted to higher positions.
10531110
/// 3. Old `dim` and `sym` entries are replaced by nullptr
10541111
/// This avoids the need for any bookkeeping.
1112+
/// If `replaceAffineMin` is set to true, additionally triggers more expensive
1113+
/// replacements involving affine_min operations.
10551114
static LogicalResult replaceDimOrSym(AffineMap *map,
10561115
unsigned dimOrSymbolPosition,
10571116
SmallVectorImpl<Value> &dims,
1058-
SmallVectorImpl<Value> &syms) {
1117+
SmallVectorImpl<Value> &syms,
1118+
bool replaceAffineMin) {
10591119
MLIRContext *ctx = map->getContext();
10601120
bool isDimReplacement = (dimOrSymbolPosition < dims.size());
10611121
unsigned pos = isDimReplacement ? dimOrSymbolPosition
@@ -1064,6 +1124,13 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
10641124
if (!v)
10651125
return failure();
10661126

1127+
auto minOp = v.getDefiningOp<AffineMinOp>();
1128+
if (minOp && replaceAffineMin) {
1129+
AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
1130+
: getAffineSymbolExpr(pos, ctx);
1131+
return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map);
1132+
}
1133+
10671134
auto affineApply = v.getDefiningOp<AffineApplyOp>();
10681135
if (!affineApply)
10691136
return failure();
@@ -1101,7 +1168,8 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
11011168
/// iteratively. Perform canonicalization of map and operands as well as
11021169
/// AffineMap simplification. `map` and `operands` are mutated in place.
11031170
static void composeAffineMapAndOperands(AffineMap *map,
1104-
SmallVectorImpl<Value> *operands) {
1171+
SmallVectorImpl<Value> *operands,
1172+
bool composeAffineMin = false) {
11051173
if (map->getNumResults() == 0) {
11061174
canonicalizeMapAndOperands(map, operands);
11071175
*map = simplifyAffineMap(*map);
@@ -1122,7 +1190,8 @@ static void composeAffineMapAndOperands(AffineMap *map,
11221190
while (true) {
11231191
bool changed = false;
11241192
for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1125-
if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1193+
if ((changed |=
1194+
succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin))))
11261195
break;
11271196
if (!changed)
11281197
break;
@@ -1163,38 +1232,41 @@ static void composeAffineMapAndOperands(AffineMap *map,
11631232
}
11641233

11651234
void mlir::affine::fullyComposeAffineMapAndOperands(
1166-
AffineMap *map, SmallVectorImpl<Value> *operands) {
1235+
AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
11671236
while (llvm::any_of(*operands, [](Value v) {
11681237
return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
11691238
})) {
1170-
composeAffineMapAndOperands(map, operands);
1239+
composeAffineMapAndOperands(map, operands, composeAffineMin);
11711240
}
11721241
}
11731242

11741243
AffineApplyOp
11751244
mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
1176-
ArrayRef<OpFoldResult> operands) {
1245+
ArrayRef<OpFoldResult> operands,
1246+
bool composeAffineMin) {
11771247
SmallVector<Value> valueOperands;
11781248
map = foldAttributesIntoMap(b, map, operands, valueOperands);
1179-
composeAffineMapAndOperands(&map, &valueOperands);
1249+
composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
11801250
assert(map);
11811251
return b.create<AffineApplyOp>(loc, map, valueOperands);
11821252
}
11831253

11841254
AffineApplyOp
11851255
mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
1186-
ArrayRef<OpFoldResult> operands) {
1256+
ArrayRef<OpFoldResult> operands,
1257+
bool composeAffineMin) {
11871258
return makeComposedAffineApply(
11881259
b, loc,
11891260
AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}, b.getContext())
11901261
.front(),
1191-
operands);
1262+
operands, composeAffineMin);
11921263
}
11931264

11941265
/// Composes the given affine map with the given list of operands, pulling in
11951266
/// the maps from any affine.apply operations that supply the operands.
11961267
static void composeMultiResultAffineMap(AffineMap &map,
1197-
SmallVectorImpl<Value> &operands) {
1268+
SmallVectorImpl<Value> &operands,
1269+
bool composeAffineMin = false) {
11981270
// Compose and canonicalize each expression in the map individually because
11991271
// composition only applies to single-result maps, collecting potentially
12001272
// duplicate operands in a single list with shifted dimensions and symbols.
@@ -1203,7 +1275,8 @@ static void composeMultiResultAffineMap(AffineMap &map,
12031275
for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
12041276
SmallVector<Value> submapOperands(operands.begin(), operands.end());
12051277
AffineMap submap = map.getSubMap({i});
1206-
fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1278+
fullyComposeAffineMapAndOperands(&submap, &submapOperands,
1279+
composeAffineMin);
12071280
canonicalizeMapAndOperands(&submap, &submapOperands);
12081281
unsigned numNewDims = submap.getNumDims();
12091282
submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
@@ -1221,10 +1294,9 @@ static void composeMultiResultAffineMap(AffineMap &map,
12211294
canonicalizeMapAndOperands(&map, &operands);
12221295
}
12231296

1224-
OpFoldResult
1225-
mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1226-
AffineMap map,
1227-
ArrayRef<OpFoldResult> operands) {
1297+
OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
1298+
OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
1299+
bool composeAffineMin) {
12281300
assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
12291301

12301302
// Create new builder without a listener, so that no notification is
@@ -1236,7 +1308,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
12361308

12371309
// Create op.
12381310
AffineApplyOp applyOp =
1239-
makeComposedAffineApply(newBuilder, loc, map, operands);
1311+
makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin);
12401312

12411313
// Get constant operands.
12421314
SmallVector<Attribute> constOperands(applyOp->getNumOperands());
@@ -1256,26 +1328,25 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
12561328
return llvm::getSingleElement(foldResults);
12571329
}
12581330

1259-
OpFoldResult
1260-
mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1261-
AffineExpr expr,
1262-
ArrayRef<OpFoldResult> operands) {
1331+
OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
1332+
OpBuilder &b, Location loc, AffineExpr expr,
1333+
ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
12631334
return makeComposedFoldedAffineApply(
12641335
b, loc,
12651336
AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}, b.getContext())
12661337
.front(),
1267-
operands);
1338+
operands, composeAffineMin);
12681339
}
12691340

12701341
SmallVector<OpFoldResult>
12711342
mlir::affine::makeComposedFoldedMultiResultAffineApply(
1272-
OpBuilder &b, Location loc, AffineMap map,
1273-
ArrayRef<OpFoldResult> operands) {
1274-
return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1275-
[&](unsigned i) {
1276-
return makeComposedFoldedAffineApply(
1277-
b, loc, map.getSubMap({i}), operands);
1278-
});
1343+
OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
1344+
bool composeAffineMin) {
1345+
return llvm::map_to_vector(
1346+
llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
1347+
return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
1348+
operands, composeAffineMin);
1349+
});
12791350
}
12801351

12811352
template <typename OpTy>
@@ -3024,7 +3095,8 @@ void AffineIfOp::build(OpBuilder &builder, OperationState &result,
30243095
/// `set` by composing the maps of such affine.apply ops with the integer
30253096
/// set constraints.
30263097
static void composeSetAndOperands(IntegerSet &set,
3027-
SmallVectorImpl<Value> &operands) {
3098+
SmallVectorImpl<Value> &operands,
3099+
bool composeAffineMin) {
30283100
// We will simply reuse the API of the map composition by viewing the LHSs of
30293101
// the equalities and inequalities of `set` as the affine exprs of an affine
30303102
// map. Convert to equivalent map, compose, and convert back to set.
@@ -3035,7 +3107,7 @@ static void composeSetAndOperands(IntegerSet &set,
30353107
[](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
30363108
return;
30373109

3038-
composeAffineMapAndOperands(&map, &operands);
3110+
composeAffineMapAndOperands(&map, &operands, composeAffineMin);
30393111
set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
30403112
set.getEqFlags());
30413113
}
@@ -3044,7 +3116,7 @@ static void composeSetAndOperands(IntegerSet &set,
30443116
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
30453117
auto set = getIntegerSet();
30463118
SmallVector<Value, 4> operands(getOperands());
3047-
composeSetAndOperands(set, operands);
3119+
composeSetAndOperands(set, operands, /*composeAffineMin=*/false);
30483120
canonicalizeSetAndOperands(&set, &operands);
30493121

30503122
// Check if the canonicalization or composition led to any change.

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
8484
getDimsToSize(rewriter, indexingSizes, options);
8585

8686
// For each dimension in the operand's shape, iterate over indexingSizes and
87-
// add
87+
// add the various term contributions.
8888
for (const auto &enResults : enumerate(indexingMap.getResults())) {
8989
int64_t resultIndex = enResults.index();
9090
AffineMap partialIndexingMap = indexingMap.getSubMap(
@@ -122,7 +122,8 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
122122
AffineMap composedMap = projectedMap.compose(ceilMap);
123123
OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
124124
rewriter, loc, composedMap,
125-
{indexingSizes[paddingDim], paddingSize});
125+
{indexingSizes[paddingDim], paddingSize},
126+
/*composeAffineMin=*/true);
126127
terms.push_back(paddingDimOfr);
127128
} else {
128129
// Otherwise just set to paddingSize.

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
146146
}
147147

148148
ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
149-
ArrayRef<Value> mapOperands)
149+
ValueRange mapOperands)
150150
: Variable(map, llvm::map_to_vector(mapOperands,
151151
[](Value v) { return Variable(v); })) {}
152152

0 commit comments

Comments
 (0)