[mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface #145068
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-affine
Author: Fabian Mora (fabianmcg)
Changes
This commit makes the following changes:
Patch is 26.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145068.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index 70b127fd063ca..4659ae28ee093 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -63,4 +63,41 @@ def SimplifyBoundedAffineOpsOp
}];
}
+def SimplifyMinMaxAffineOpsOp :
+ Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
+ TransformOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformEachOpTrait
+ ]> {
+ let description = [{
+ Simplify all the affine.min / affine.max ops being targeted or nested in the
+ target operation, using the `mlir::affine::simplifyAffineMinMaxOps`
+ transform.
+
+ Example:
+ ```
+ %0 = transform.structured.match ops{["gpu.launch", "affine.max"]} in %arg1
+ transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
+ ```
+
+ #### Return modes
+
+ This transform consumes the target handle and does not produce any results.
+ This transform never produces errors.
+
+ }];
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+ let assemblyFormat = [{
+ $target attr-dict `:` type($target)
+ }];
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // Affine_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 5c538d28c1835..b0578eb159c11 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -34,6 +34,8 @@ namespace affine {
class AffineApplyOp;
class AffineDelinearizeIndexOp;
class AffineLinearizeIndexOp;
+class AffineMaxOp;
+class AffineMinOp;
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
@@ -127,6 +129,23 @@ OpFoldResult materializeComputedBound(
OpBuilder &b, Location loc, AffineMap boundMap,
ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);
+/// Tries to simplify all affine min or max operations under `topOp`. The
+/// transform works by finding disjoint sets of affine result expressions
+/// bounded by a common affine expression on the min/max operation. It populates
+/// `modifiedOps` with all the operations modified by the transform.
+///
+/// In concrete terms, given an operation like:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
+/// If `d0 < 128` and `128 < s1 < s0`, the transform will update the op to:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`.
+void simplifyAffineMinMaxOps(RewriterBase &rewriter, Operation *topOp,
+ SmallVectorImpl<Operation *> &modifiedOps);
+/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
+/// the operation was modified.
+bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);
+/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
+/// the operation was modified.
+bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);
} // namespace affine
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 337314143c80c..39206b89ef8c6 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -135,10 +135,17 @@ class ValueBoundsConstraintSet
/// Construct a variable for a map and its operands.
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
- Variable(AffineMap map, ArrayRef<Value> mapOperands);
+ Variable(AffineMap map, ValueRange mapOperands);
MLIRContext *getContext() const { return map.getContext(); }
+ /// Returns the affine map.
+ AffineMap getMap() const { return map; }
+
+ /// Returns the map operands.
+ ValueDimList &getOperands() { return mapOperands; }
+ const ValueDimList &getOperands() const { return mapOperands; }
+
private:
friend class ValueBoundsConstraintSet;
AffineMap map;
@@ -254,6 +261,12 @@ class ValueBoundsConstraintSet
/// prove the relation or until it ran out of IR.
static bool compare(const Variable &lhs, ComparisonOperator cmp,
const Variable &rhs);
+ /// This function is similar to `ValueBoundsConstraintSet::compare`, except
+ /// that it returns false if `!(lhs cmp rhs)`, and `std::nullopt` if the
+ /// values couldn't be compared.
+ static std::optional<bool> strongCompare(const Variable &lhs,
+ ComparisonOperator cmp,
+ const Variable &rhs);
/// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
@@ -327,6 +340,16 @@ class ValueBoundsConstraintSet
/// constraints.
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
+ /// Return "true" if, based on the current state of the constraint system,
+ /// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
+ /// can be proven. Otherwise it returns `std::nullopt` meaning the values are
+ /// unordered with respect to the constraints.
+ ///
+ /// This function does not analyze any IR and does not populate any additional
+ /// constraints.
+ std::optional<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
+ int64_t rhsPos);
+
/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index c9fe4474a68fa..6bd4dd23c7af5 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -148,6 +149,24 @@ void SimplifyBoundedAffineOpsOp::getEffects(
modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// SimplifyMinMaxAffineOpsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure SimplifyMinMaxAffineOpsOp::applyToOne(
+ TransformRewriter &rewriter, Operation *target,
+ ApplyToEachResultList &results, TransformState &state) {
+ SmallVector<Operation *> modifiedOps;
+ simplifyAffineMinMaxOps(rewriter, target, modifiedOps);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void SimplifyMinMaxAffineOpsOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 1c82822b2bd7f..c792200f4a49a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
+ SimplifyAffineMinMax.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
new file mode 100644
index 0000000000000..0ddf2ec192a3e
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -0,0 +1,153 @@
+//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a transform to simplify min/max affine operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/IntEqClasses.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "affine-min-max"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::affine;
+
+/// Simplifies an affine min/max operation by proving there's a lower or upper
+/// bound.
+template <typename AffineOp>
+static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
+ using Variable = ValueBoundsConstraintSet::Variable;
+ using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
+
+ AffineMap affineMap = affineOp.getMap();
+ ValueRange operands = affineOp.getOperands();
+ static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
+
+ LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+
+ // Create a `Variable` list with values corresponding to each of the results
+ // in the affine map.
+ SmallVector<Variable> variables = llvm::map_to_vector(
+ llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
+ [&](unsigned i) {
+ return Variable(affineMap.getSliceMap(i, 1), operands);
+ });
+
+ // Get the comparison operation.
+ ComparisonOperator cmpOp =
+ isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
+
+ // Find disjoint sets bounded by a common value.
+ llvm::IntEqClasses boundedClasses(variables.size());
+ DenseMap<unsigned, Variable *> bounds;
+ for (auto &&[i, v] : llvm::enumerate(variables)) {
+ unsigned eqClass = boundedClasses.findLeader(i);
+
+ // If the class already has a bound continue.
+ if (bounds.contains(eqClass))
+ continue;
+
+ // Initialize the bound.
+ Variable *bound = &v;
+
+ LLVM_DEBUG({
+ DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+ << "`\n";
+ });
+
+ // Check against the other variables.
+ for (size_t j = i + 1; j < variables.size(); ++j) {
+ unsigned jEqClass = boundedClasses.findLeader(j);
+ // Get the bound of the equivalence class or itself.
+ Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
+
+ LLVM_DEBUG({
+ DBGS() << "- comparing with variable: #" << jEqClass
+ << ", with map: " << nv->getMap() << "\n";
+ });
+
+ // Compare the variables.
+ std::optional<bool> cmpResult =
+ ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
+
+ // The variables cannot be compared.
+ if (!cmpResult) {
+ LLVM_DEBUG({
+ DBGS() << "-- classes: #" << i << ", #" << jEqClass
+ << " cannot be merged\n";
+ });
+ continue;
+ }
+
+ // Join the equivalent classes and update the bound if necessary.
+ LLVM_DEBUG({
+ DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
+ << ", comparison holds: " << *cmpResult << "\n";
+ });
+ if (*cmpResult) {
+ boundedClasses.join(eqClass, jEqClass);
+ } else {
+ // In this case we have lhs > rhs if isMin == true, or lhs < rhs if
+ // isMin == false.
+ bound = nv;
+ boundedClasses.join(eqClass, jEqClass);
+ }
+ }
+ bounds[boundedClasses.findLeader(i)] = bound;
+ }
+
+ // Return if there's no simplification.
+ if (bounds.size() >= affineMap.getNumResults()) {
+ LLVM_DEBUG(
+ { DBGS() << "- the affine operation couldn't get simplified\n"; });
+ return false;
+ }
+
+ // Construct the new affine map.
+ SmallVector<AffineExpr> results;
+ results.reserve(bounds.size());
+ for (auto [k, bound] : bounds)
+ results.push_back(bound->getMap().getResult(0));
+
+ affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
+ results, rewriter.getContext());
+
+ // Update the affine op.
+ rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
+ LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
+ return true;
+}
+
+bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
+ return simplifyAffineMinMaxOp(rewriter, op);
+}
+
+bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
+ return simplifyAffineMinMaxOp(rewriter, op);
+}
+
+void mlir::affine::simplifyAffineMinMaxOps(
+ RewriterBase &rewriter, Operation *topOp,
+ SmallVectorImpl<Operation *> &modifiedOps) {
+ assert(topOp && "null-op");
+ topOp->walk([&](Operation *op) {
+ if (auto affineOp = dyn_cast<AffineMinOp>(op)) {
+ if (simplifyAffineMinMaxOp(rewriter, affineOp))
+ modifiedOps.push_back(op);
+ } else if (auto affineOp = dyn_cast<AffineMaxOp>(op)) {
+ if (simplifyAffineMinMaxOp(rewriter, affineOp))
+ modifiedOps.push_back(op);
+ }
+ });
+}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 87f883c2e6485..d7a6187cafb1e 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -146,7 +146,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
}
ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
- ArrayRef<Value> mapOperands)
+ ValueRange mapOperands)
: Variable(map, llvm::map_to_vector(mapOperands,
[](Value v) { return Variable(v); })) {}
@@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}
+std::optional<bool> ValueBoundsConstraintSet::strongComparePos(
+ int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
+ auto strongCmp = [&](ComparisonOperator cmp,
+ ComparisonOperator negCmp) -> std::optional<bool> {
+ if (comparePos(lhsPos, cmp, rhsPos))
+ return true;
+ if (comparePos(lhsPos, negCmp, rhsPos))
+ return false;
+ return std::nullopt;
+ };
+ switch (cmp) {
+ case ComparisonOperator::LT:
+ return strongCmp(ComparisonOperator::LT, ComparisonOperator::GE);
+ case ComparisonOperator::LE:
+ return strongCmp(ComparisonOperator::LE, ComparisonOperator::GT);
+ case ComparisonOperator::GT:
+ return strongCmp(ComparisonOperator::GT, ComparisonOperator::LE);
+ case ComparisonOperator::GE:
+ return strongCmp(ComparisonOperator::GE, ComparisonOperator::LT);
+ case ComparisonOperator::EQ: {
+ std::optional<bool> le =
+ strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos);
+ if (!le)
+ return std::nullopt;
+ if (!*le)
+ return false;
+ std::optional<bool> ge =
+ strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos);
+ if (!ge)
+ return std::nullopt;
+ if (!*ge)
+ return false;
+ return true;
+ }
+ }
+ llvm_unreachable("invalid comparison operator");
+}
+
bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
ComparisonOperator cmp,
const Variable &rhs) {
@@ -763,6 +801,25 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
return cstr.comparePos(lhsPos, cmp, rhsPos);
}
+std::optional<bool> ValueBoundsConstraintSet::strongCompare(
+ const Variable &lhs, ComparisonOperator cmp, const Variable &rhs) {
+ int64_t lhsPos = -1, rhsPos = -1;
+ auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+ ValueBoundsConstraintSet &cstr) {
+ // Keep processing as long as lhs/rhs were not processed.
+ if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
+ size_t(rhsPos) >= cstr.positionToValueDim.size())
+ return false;
+ // Keep processing as long as the strong relation cannot be proven.
+ std::optional<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
+ return ordered ? true : false;
+ };
+ ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+ lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+ rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
+ return cstr.strongComparePos(lhsPos, cmp, rhsPos);
+}
+
FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
const Variable &var2) {
if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
new file mode 100644
index 0000000000000..2b6de62073e99
--- /dev/null
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
+// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
+// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
+
+// CHECK: @min_max_full_simplify
+func.func @min_max_full_simplify() -> (index, index) {
+ %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK-NOT: affine.min
+ // CHECK-NOT: affine.max
+ // CHECK: return %[[V0]], %[[V1]]
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// CHECK: @min_only_simplify
+func.func @min_only_simplify() -> (index, index) {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
+ // CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ %0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// CHECK: @max_only_simplify
+func.func @max_only_simplify() -> (index, index) {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ // CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
+ %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// CHECK: @overlapping_constraints
+func.func @ov...
[truncated]
LGTM for the `ValueBoundsConstraintSet` part. I did not review the other changes in-depth.
I saw that there is `SimplifyAffineMinMaxOp` and `SimplifyBoundedAffineOpsOp`. Can we combine those two with `SimplifyAffineMinMax.cpp`?
I don't know if %iA = arith.assume %i constraints min = 0, max = 128 : index As |
Is IREE using |
I just checked and no. But I would wait for removal until we have the |
/// This function is similar to `ValueBoundsConstraintSet::compare`, except
/// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the
/// relation nor its inverse relation could be proven.
static llvm::FailureOr<bool> strongCompare(const Variable &lhs,
isn't the proper mathematical term "strict" comparison?
AFAIK, no. Here, strong doesn't refer to the comparison operator, but rather the method guarantees compared to the `ValueBoundsConstraintSet::compare` method. To be specific:
compare(x, cmp, y) == true <==> strongCompare(x, cmp, y) == true
&&
strongCompare(x, cmp, y) == false ==> compare(x, cmp, y) == false
// But
compare(x, cmp, y) == false =/=> strongCompare(x, cmp, y) == false
That's because `compare(x, cmp, y) == false` can mean the inequality couldn't be proven (not necessarily that it's false), and in those cases `strongCompare` returns failure.
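The relationship described above can be illustrated with a toy model. The sketch below is not the MLIR implementation: it assumes each value is known only through an independent closed range (the hypothetical `Interval` struct), treats the weak `compare` as "provable for every pair of values", and derives `strongCompare` from it by also trying the negated operator, which is the same shape as `strongComparePos` in the diff. It returns `std::optional<bool>` as in the posted patch, where `std::nullopt` plays the role of "failure".

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for a value with known bounds; not an MLIR type.
struct Interval {
  int64_t lo;
  int64_t hi;
};

enum class Cmp { LT, LE, GT, GE };

// Weak comparison: true only when `lhs cmp rhs` holds for every pair of
// values drawn from the two ranges. A false result means "not proven",
// which is different from "proven false".
bool compare(Interval lhs, Cmp cmp, Interval rhs) {
  assert(lhs.lo <= lhs.hi && rhs.lo <= rhs.hi && "malformed interval");
  switch (cmp) {
  case Cmp::LT:
    return lhs.hi < rhs.lo;
  case Cmp::LE:
    return lhs.hi <= rhs.lo;
  case Cmp::GT:
    return lhs.lo > rhs.hi;
  case Cmp::GE:
    return lhs.lo >= rhs.hi;
  }
  return false;
}

// Logical negation of a comparison operator.
Cmp negate(Cmp cmp) {
  switch (cmp) {
  case Cmp::LT:
    return Cmp::GE;
  case Cmp::LE:
    return Cmp::GT;
  case Cmp::GT:
    return Cmp::LE;
  case Cmp::GE:
    return Cmp::LT;
  }
  return cmp;
}

// Strong comparison: true if the relation is proven, false if its negation
// is proven, std::nullopt if neither can be shown.
std::optional<bool> strongCompare(Interval lhs, Cmp cmp, Interval rhs) {
  if (compare(lhs, cmp, rhs))
    return true;
  if (compare(lhs, negate(cmp), rhs))
    return false;
  return std::nullopt;
}
```

In this model, comparing `[0, 128]` with `[0, 512]` under `LT` makes `compare` return false while `strongCompare` returns `std::nullopt`, which is the `=/=>` case listed above.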
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
llvm::FailureOr<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
strict?
@@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}

FailureOr<bool> ValueBoundsConstraintSet::strongComparePos(
strict ?
@@ -763,14 +801,29 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
return cstr.comparePos(lhsPos, cmp, rhsPos);
}

FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs,
strict?
// Compare the variables.
FailureOr<bool> cmpResult =
ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
Can you elaborate why you need a strict comparison here?
It is not clear to me: redundant constraints can be dropped with regular comparisons so if you had to do this I am wondering if there is something off.
In particular we need: `forall x f(x) <= g(x)` (and not e.g. `exists x such that f(x) < g(x)` but I don't see evidence that this is what is being computed).
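The distinction between the universal and the existential condition can be checked by brute force on a small range. The sketch below is illustrative only: `Fn`, `leForAll`, `ltExists`, and `canDropFromMin` are hypothetical helpers, not part of MLIR, and enumerating every `x` merely stands in for what a constraint solver would prove symbolically.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>

using Fn = std::function<int64_t(int64_t)>;

// True when f(x) <= g(x) for every x in [lo, hi].
bool leForAll(const Fn &f, const Fn &g, int64_t lo, int64_t hi) {
  assert(lo <= hi && "empty range");
  for (int64_t x = lo; x <= hi; ++x)
    if (f(x) > g(x))
      return false;
  return true;
}

// True when f(x) < g(x) for at least one x in [lo, hi].
bool ltExists(const Fn &f, const Fn &g, int64_t lo, int64_t hi) {
  assert(lo <= hi && "empty range");
  for (int64_t x = lo; x <= hi; ++x)
    if (f(x) < g(x))
      return true;
  return false;
}

// True when min(f(x), g(x)) == f(x) on the whole range, i.e. when `g` can be
// removed from the min without changing any result.
bool canDropFromMin(const Fn &f, const Fn &g, int64_t lo, int64_t hi) {
  assert(lo <= hi && "empty range");
  for (int64_t x = lo; x <= hi; ++x)
    if (std::min(f(x), g(x)) != f(x))
      return false;
  return true;
}
```

With `f(x) = x` and `g(x) = 32` on `[0, 512]`, `ltExists` holds (for example at `x = 0`) but `leForAll` does not, and dropping `g` would change the result for any `x > 32`. Only the universal condition justifies removing a result from the min.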
Because I don't know whether you're objecting to using `strongCompare` or `LT, GT` as comparators, I'll provide 2 answers.
- The reason for using `strongCompare` is: comparison with `ValueBoundsConstraintSet` is a weak order. Therefore, we need to know if they do or do not compare, otherwise, we might find or miss a min when we shouldn't or should.
- I can switch it to `LE, GE`, but it shouldn't make a difference. Because if `x < y`, then I keep the bound in `x`, but if `x >= y` I switch the bound to `y`, switching to `x <= y` will update `x >= y` to `x > y`, so in theory the same expressions.
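The bound-selection rule in the second answer can be sketched on a simplified model. This is a deliberately reduced variant, not the PR's algorithm: it assumes each result of the min is known only through an independent range (the hypothetical `Interval`), and it keeps a plain list of surviving results instead of the `llvm::IntEqClasses` equivalence classes used in `SimplifyAffineMinMax.cpp`. The names `strongLess` and `simplifyMin` are made up for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for one result of an affine.min map: all that is
// known about it is a closed range [lo, hi]. This is not an MLIR type.
struct Interval {
  int64_t lo;
  int64_t hi;
};

// Three-valued `lhs < rhs`: true if proven, false if `lhs >= rhs` is proven,
// std::nullopt if the ranges overlap and neither can be shown.
std::optional<bool> strongLess(Interval lhs, Interval rhs) {
  assert(lhs.lo <= lhs.hi && rhs.lo <= rhs.hi && "malformed interval");
  if (lhs.hi < rhs.lo)
    return true;
  if (lhs.lo >= rhs.hi)
    return false;
  return std::nullopt;
}

// Returns the indices of the results that must stay in the min.
std::vector<std::size_t> simplifyMin(const std::vector<Interval> &results) {
  std::vector<std::size_t> kept;
  for (std::size_t j = 0; j < results.size(); ++j) {
    // If a kept bound is proven smaller, result j can never be the min.
    bool redundant = std::any_of(kept.begin(), kept.end(), [&](std::size_t k) {
      return strongLess(results[k], results[j]) == std::optional<bool>(true);
    });
    if (redundant)
      continue;
    // Otherwise result j replaces every kept bound proven to be >= j.
    kept.erase(std::remove_if(kept.begin(), kept.end(),
                              [&](std::size_t k) {
                                std::optional<bool> less =
                                    strongLess(results[k], results[j]);
                                return less.has_value() && !*less;
                              }),
               kept.end());
    // Results that cannot be ordered against the kept bounds stay as well.
    kept.push_back(j);
  }
  return kept;
}
```

Feeding it the ranges from the PR's tests gives the expected shape: for `(s0 in [0, 128], 192, s1 in [256, 512])` only `s0` survives, while for `(s0 in [0, 512], 32, s1 in [256, 512])` both `s0` and `32` remain because their ranges overlap and neither order can be proven. The order of the surviving results may differ from what the actual transform emits.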
Great stuff @fabianmcg !
I was a little confused because I didn't see how the new ops pass their information to the analysis but then I realized it's populateBoundsForIndexValue thanks to the good design of the interface (thanks @matthias-springer :) ).
Looks good to me modulo my question regarding the need for strict comparison. If there is a simple answer I missed feel free to just land after addressing nits.
Co-authored-by: Nicolas Vasilache <[email protected]>
transform.apply_patterns.canonicalization
} {apply_cse} : !transform.any_op
transform.affine.simplify_min_max_affine_ops %2 : !transform.any_op
%3 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
hmmm why do you need to rematch here?
You shouldn't need to (see `apply_registered_pass` and `apply_patterns` that don't lose the handle `%2`, you shouldn't either).
Currently, `simplify_min_max_affine_ops` always consumes and invalidates the handle. I'll update the transform op.
This PR makes OnlineAttention derive from IndexingMapOpInterface and makes it pad with transform.structured.pad_tiling_interface. Additionally, it ensures the dynamic case pads to a constant before tiling and properly canonicalizes to constant shapes once AffineMin simplification kicks in. This requires integrating LLVM past llvm/llvm-project#145068 once it has landed.
…ransform.pad-tiling-interface This revision introduces a simple variant of AffineMin folding in makeComposedFoldedAffineApply and makes use of it in transform.pad-tiling-interface. Since this version explicitly calls ValueBoundsInterface, it may be too expensive and is only activated behind a flag. It results in better foldings when mixing tiling and padding, including with dynamic shapes. This should be further composed with #145068 to provide full simplification and address the remaining TODO in the test.
This commit makes the following changes:
- Expose `map` and `mapOperands` in `ValueBoundsConstraintSet::Variable`, so that the class can be used by subclasses of `ValueBoundsConstraintSet`. Otherwise subclasses cannot access those members.
- Add `ValueBoundsConstraintSet::strongCompare`. This method is similar to `ValueBoundsConstraintSet::compare` except that it returns false when the inverse comparison holds, and `llvm::failure()` if neither the relation nor its inverse relation could be proven.
- Add `simplifyAffineMinOp`, `simplifyAffineMaxOp`, and `simplifyAffineMinMaxOps` to simplify those operations using `ValueBoundsConstraintSet`.
- Adds the `SimplifyMinMaxAffineOpsOp` transform op that uses `simplifyAffineMinMaxOps`.
- Add the `test.value_with_bounds` op to test unknown values with a min max range using `ValueBoundsOpInterface`.
- Adds tests verifying the transform.
Example: