Skip to content

[mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface #145068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,35 @@ def SimplifyBoundedAffineOpsOp
}];
}

def SimplifyMinMaxAffineOpsOp :
Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let description = [{
Simplify the targeted `affine.min` / `affine.max` ops using the
`mlir::affine::simplifyAffineMinMaxOps` transform.

Example:
```
%0 = transform.structured.match ops{["affine.max"]} in %arg1
transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
```

#### Return modes

This transform consumes the target handle and does not produce any results.
This transforms definitely fails if any of the targeted operations is not an
`affine.min` or `affine.max` operation, or if the canonicalization patterns
failed to converge.
This transform silently fails if none of the operations were simplified.
Otherwise, it succeeds.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs);
let assemblyFormat = [{
$target attr-dict `:` type($target)
}];
}

#endif // Affine_TRANSFORM_OPS
33 changes: 33 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace affine {
class AffineApplyOp;
class AffineDelinearizeIndexOp;
class AffineLinearizeIndexOp;
class AffineMaxOp;
class AffineMinOp;

/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
Expand Down Expand Up @@ -127,6 +129,37 @@ OpFoldResult materializeComputedBound(
OpBuilder &b, Location loc, AffineMap boundMap,
ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);

/// This transform tries to simplify the affine min operation `op`, by finding a
/// common lower bound for a set of expressions in the affine map results. It
/// returns whether the transform updated `op`'s affine map.
///
/// In concrete terms, given an operation like:
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
/// If `d0 < 128` and `128 < s1 < s0`, the transform will update `op` to:
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`.
bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);

/// This transform tries to simplify the affine max operation `op`, by finding a
/// common upper bound for a set of expressions in the affine map results. It
/// returns whether the transform updated `op`'s affine map.
///
/// In concrete terms, given an operation like:
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
/// If `d0 > 128` and `s0 > s1 > 128`, the transform will update `op` to:
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s0)>(%i)[%s0, %s1]`.
bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);

/// This transform applies `simplifyAffineMinOp` and `simplifyAffineMaxOp` to
/// all the `affine.min` or `affine.max` operations in `ops`. After
/// simplification, it invokes the `affine.min/max` canonicalization patterns on
/// `ops`.
///
/// This transform returns failure if the greedy pattern rewriter failed to
/// converge during canonicalization, otherwise it returns success. If provided,
/// `modified` is set to `true` if the IR was modified in any way.
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter,
ArrayRef<Operation *> ops,
bool *modified = nullptr);
} // namespace affine
} // namespace mlir

Expand Down
25 changes: 24 additions & 1 deletion mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,17 @@ class ValueBoundsConstraintSet

/// Construct a variable for a map and its operands.
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
Variable(AffineMap map, ArrayRef<Value> mapOperands);
Variable(AffineMap map, ValueRange mapOperands);

MLIRContext *getContext() const { return map.getContext(); }

/// Returns the affine map.
AffineMap getMap() const { return map; }

/// Returns the map operands.
ValueDimList &getOperands() { return mapOperands; }
const ValueDimList &getOperands() const { return mapOperands; }

private:
friend class ValueBoundsConstraintSet;
AffineMap map;
Expand Down Expand Up @@ -254,6 +261,12 @@ class ValueBoundsConstraintSet
/// prove the relation or until it ran out of IR.
static bool compare(const Variable &lhs, ComparisonOperator cmp,
const Variable &rhs);
/// This function is similar to `ValueBoundsConstraintSet::compare`, except
/// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the
/// relation nor its inverse relation could be proven.
static llvm::FailureOr<bool> strongCompare(const Variable &lhs,
ComparisonOperator cmp,
const Variable &rhs);

/// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
Expand Down Expand Up @@ -327,6 +340,16 @@ class ValueBoundsConstraintSet
/// constraints.
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);

/// Return "true" if, based on the current state of the constraint system,
/// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
/// can be proven. Otherwise, it returns `failure` if neither the relation nor
/// its inverse relation could be proven.
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
llvm::FailureOr<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
int64_t rhsPos);

/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long
Expand Down
39 changes: 38 additions & 1 deletion mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -112,7 +113,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
}
if (boundedOps.contains(target)) {
auto diag = emitDefiniteFailure()
<< "target op result must not be constrainted";
<< "target op result must not be constrained";
diag.attachNote(target->getLoc()) << "target/constrained op";
return diag;
}
Expand Down Expand Up @@ -148,6 +149,42 @@ void SimplifyBoundedAffineOpsOp::getEffects(
modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// SimplifyMinMaxAffineOpsOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
TransformResults &results,
TransformState &state) {
SmallVector<Operation *> targets;
for (Operation *target : state.getPayloadOps(getTarget())) {
if (!isa<AffineMinOp, AffineMaxOp>(target)) {
auto diag = emitDefiniteFailure()
<< "target must be affine.min or affine.max";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
targets.push_back(target);
}
bool modified = false;
if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, targets,
&modified))) {
return emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
}
if (!modified) {
return emitSilenceableError()
<< "the transform failed to simplify any of the target operations";
}
return DiagnosedSilenceableFailure::success();
}

void SimplifyMinMaxAffineOpsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects);
modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
SimplifyAffineMinMax.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
Expand Down
174 changes: 174 additions & 0 deletions mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a transform to simplify mix/max affine operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/IntEqClasses.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "affine-min-max"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::affine;

/// Simplifies an affine min/max operation by proving there's a lower or upper
/// bound.
template <typename AffineOp>
static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
using Variable = ValueBoundsConstraintSet::Variable;
using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;

AffineMap affineMap = affineOp.getMap();
ValueRange operands = affineOp.getOperands();
static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;

LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });

// Create a `Variable` list with values corresponding to each of the results
// in the affine affineMap.
SmallVector<Variable> variables = llvm::map_to_vector(
llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
[&](unsigned i) {
return Variable(affineMap.getSliceMap(i, 1), operands);
});

// Get the comparison operation.
ComparisonOperator cmpOp =
isMin ? ComparisonOperator::LT : ComparisonOperator::GT;

// Find disjoint sets bounded by a common value.
llvm::IntEqClasses boundedClasses(variables.size());
DenseMap<unsigned, Variable *> bounds;
for (auto &&[i, v] : llvm::enumerate(variables)) {
unsigned eqClass = boundedClasses.findLeader(i);

// If the class already has a bound continue.
if (bounds.contains(eqClass))
continue;

// Initialize the bound.
Variable *bound = &v;

LLVM_DEBUG({
DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
<< "`\n";
});

// Check against the other variables.
for (size_t j = i + 1; j < variables.size(); ++j) {
unsigned jEqClass = boundedClasses.findLeader(j);
// Skip if the class is the same.
if (jEqClass == eqClass)
continue;

// Get the bound of the equivalence class or itself.
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);

LLVM_DEBUG({
DBGS() << "- comparing with variable: #" << jEqClass
<< ", with map: " << nv->getMap() << "\n";
});

// Compare the variables.
FailureOr<bool> cmpResult =
ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);

// The variables cannot be compared.
if (failed(cmpResult)) {
LLVM_DEBUG({
DBGS() << "-- classes: #" << i << ", #" << jEqClass
<< " cannot be merged\n";
});
continue;
}

// Join the equivalent classes and update the bound if necessary.
LLVM_DEBUG({
DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
<< ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
});
if (*cmpResult) {
boundedClasses.join(eqClass, jEqClass);
} else {
// In this case we have lhs > rhs if isMin == true, or lhs < rhs if
// isMin == false.
bound = nv;
boundedClasses.join(eqClass, jEqClass);
}
}
bounds[boundedClasses.findLeader(i)] = bound;
}

// Return if there's no simplification.
if (bounds.size() >= affineMap.getNumResults()) {
LLVM_DEBUG(
{ DBGS() << "- the affine operation couldn't get simplified\n"; });
return false;
}

// Construct the new affine affineMap.
SmallVector<AffineExpr> results;
results.reserve(bounds.size());
for (auto [k, bound] : bounds)
results.push_back(bound->getMap().getResult(0));

affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
results, rewriter.getContext());

// Update the affine op.
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
return true;
}

bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
return simplifyAffineMinMaxOp(rewriter, op);
}

bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
return simplifyAffineMinMaxOp(rewriter, op);
}

LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
ArrayRef<Operation *> ops,
bool *modified) {
bool changed = false;
for (Operation *op : ops) {
if (auto minOp = dyn_cast<AffineMinOp>(op))
changed = simplifyAffineMinOp(rewriter, minOp) || changed;
else if (auto maxOp = cast<AffineMaxOp>(op))
changed = simplifyAffineMaxOp(rewriter, maxOp) || changed;
}
RewritePatternSet patterns(rewriter.getContext());
AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (modified)
*modified = changed;
// Canonicalize to a fixpoint.
if (failed(applyOpPatternsGreedily(
ops, frozenPatterns,
GreedyRewriteConfig()
.setListener(
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps),
&changed))) {
return failure();
}
if (modified)
*modified = changed;
return success();
}
Loading
Loading