Skip to content

[mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface #145068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

fabianmcg
Copy link
Contributor

@fabianmcg fabianmcg commented Jun 20, 2025

This commit makes the following changes:

  • Expose map and mapOperands in ValueBoundsConstraintSet::Variable, so that the class can be used by subclasses of ValueBoundsConstraintSet. Otherwise subclasses cannot access those members.

  • Add ValueBoundsConstraintSet::strongCompare. This method is similar to ValueBoundsConstraintSet::compare except that it returns false when the inverse comparison holds, and llvm::failure() if neither the relation nor its inverse relation could be proven.

  • Add simplifyAffineMinOp, simplifyAffineMaxOp, and simplifyAffineMinMaxOps to simplify those operations using ValueBoundsConstraintSet.

  • Adds the SimplifyMinMaxAffineOpsOp transform op that uses simplifyAffineMinMaxOps.

  • Add the test.value_with_bounds op to test unknown values with a min max range using ValueBoundsOpInterface.

  • Adds tests verifying the transform.

Example:

func.func @overlapping_constraints() -> (index, index) {
  %0 = test.value_with_bounds {min = 0 : index, max = 192 : index}
  %1 = test.value_with_bounds {min = 128 : index, max = 384 : index}
  %2 = test.value_with_bounds {min = 256 : index, max = 512 : index}
  %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
  %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
  return %r0, %r1 : index, index
}
// Result of applying `simplifyAffineMinMaxOps` to `func.func`
#map1 = affine_map<()[s0, s1] -> (s1, s0)>
func.func @overlapping_constraints() -> (index, index) {
  %0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
  %1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
  %2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
  %3 = affine.min #map1()[%0, %1]
  %4 = affine.max #map1()[%1, %2]
  return %3, %4 : index, index
}

@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: Fabian Mora (fabianmcg)

Changes

This commit makes the following changes:

  • Expose map and mapOperands in ValueBoundsConstraintSet::Variable, so that the class can be used by subclasses of ValueBoundsConstraintSet.

  • Add ValueBoundsConstraintSet::strongCompare. This method is similar to ValueBoundsConstraintSet::compare except that it returns std::nullopt when the values cannot be compared, and false when the opposite comparison holds.

  • Add simplifyAffineMinOp, simplifyAffineMaxOp, and simplifyAffineMinMaxOps to simplify those operations using ValueBoundsConstraintSet.

  • Adds the SimplifyMinMaxAffineOpsOp transform op that uses simplifyAffineMinMaxOps.

  • Add the test.value_with_bounds op to test unknown values with a min max range using ValueBoundsOpInterface.

  • Adds tests verifying the transform.


Patch is 26.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145068.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td (+37)
  • (modified) mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h (+19)
  • (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+24-1)
  • (modified) mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp (+19)
  • (modified) mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp (+153)
  • (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+58-1)
  • (added) mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir (+72)
  • (modified) mlir/test/Dialect/Linalg/transform-op-pad.mlir (+35)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+10)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+19)
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index 70b127fd063ca..4659ae28ee093 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -63,4 +63,41 @@ def SimplifyBoundedAffineOpsOp
   }];
 }
 
+def SimplifyMinMaxAffineOpsOp :
+  Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
+    TransformOpInterface,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TransformEachOpTrait
+  ]> {
+  let description = [{
+    Simplify all the affine.min / affine.max ops being targeted or nested in the
+    target operation, using the `mlir::affine::simplifyAffineMinMaxOps`
+    transform.
+
+    Example:
+    ```
+    %0 = transform.structured.match ops{["gpu.launch", "affine.max"]} in %arg1
+    transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
+    ```
+
+    #### Return modes
+
+    This transform consumes the target handle and does not produce any results.
+    This transforms never produces errors.
+
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = [{
+      $target attr-dict `:` type($target)
+  }];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // Affine_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 5c538d28c1835..b0578eb159c11 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -34,6 +34,8 @@ namespace affine {
 class AffineApplyOp;
 class AffineDelinearizeIndexOp;
 class AffineLinearizeIndexOp;
+class AffineMaxOp;
+class AffineMinOp;
 
 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
 /// operations.
@@ -127,6 +129,23 @@ OpFoldResult materializeComputedBound(
     OpBuilder &b, Location loc, AffineMap boundMap,
     ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);
 
+/// Tries to simplify all affine min or max operations under `topOp`. The
+/// transform works by finding disjoint sets of affine result expressions
+/// bounded by a common affine expression on the min/max operation. It populates
+/// `modifiedOps` with all the operations modified by the transform/
+///
+/// In concrete terms, given an operation like:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
+/// If `d0 < 128` and `128 < s1 < s0`, the transform will update the op to:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`.
+void simplifyAffineMinMaxOps(RewriterBase &rewriter, Operation *topOp,
+                             SmallVectorImpl<Operation *> &modifiedOps);
+/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
+/// the operation was modified.
+bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);
+/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
+/// the operation was modified.
+bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 337314143c80c..39206b89ef8c6 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -135,10 +135,17 @@ class ValueBoundsConstraintSet
 
     /// Construct a variable for a map and its operands.
     Variable(AffineMap map, ArrayRef<Variable> mapOperands);
-    Variable(AffineMap map, ArrayRef<Value> mapOperands);
+    Variable(AffineMap map, ValueRange mapOperands);
 
     MLIRContext *getContext() const { return map.getContext(); }
 
+    /// Returns the affine map.
+    AffineMap getMap() const { return map; }
+
+    /// Returns the map operands.
+    ValueDimList &getOperands() { return mapOperands; }
+    const ValueDimList &getOperands() const { return mapOperands; }
+
   private:
     friend class ValueBoundsConstraintSet;
     AffineMap map;
@@ -254,6 +261,12 @@ class ValueBoundsConstraintSet
   /// prove the relation or until it ran out of IR.
   static bool compare(const Variable &lhs, ComparisonOperator cmp,
                       const Variable &rhs);
+  /// This function is similar to `ValueBoundsConstraintSet::compare`, except
+  /// that it returns false if `!(lhs cmp rhs)`, and `std::nullopt` if the
+  /// values couldn't be compared.
+  static std::optional<bool> strongCompare(const Variable &lhs,
+                                           ComparisonOperator cmp,
+                                           const Variable &rhs);
 
   /// Compute whether the given variables are equal. Return "failure" if
   /// equality could not be determined.
@@ -327,6 +340,16 @@ class ValueBoundsConstraintSet
   /// constraints.
   bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
 
+  /// Return "true" if, based on the current state of the constraint system,
+  /// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
+  /// can be proven. Otherwise it returns `std::nullopt` meaning the values are
+  /// unordered with respect to the constraints.
+  ///
+  /// This function does not analyze any IR and does not populate any additional
+  /// constraints.
+  std::optional<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
+                                       int64_t rhsPos);
+
   /// Given an affine map with a single result (and map operands), add a new
   /// column to the constraint set that represents the result of the map.
   /// Traverse additional IR starting from the map operands as needed (as long
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index c9fe4474a68fa..6bd4dd23c7af5 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -148,6 +149,24 @@ void SimplifyBoundedAffineOpsOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// SimplifyMinMaxAffineOpsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure SimplifyMinMaxAffineOpsOp::applyToOne(
+    TransformRewriter &rewriter, Operation *target,
+    ApplyToEachResultList &results, TransformState &state) {
+  SmallVector<Operation *> modifiedOps;
+  simplifyAffineMinMaxOps(rewriter, target, modifiedOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void SimplifyMinMaxAffineOpsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 1c82822b2bd7f..c792200f4a49a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   ReifyValueBounds.cpp
   SuperVectorize.cpp
   SimplifyAffineStructures.cpp
+  SimplifyAffineMinMax.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
new file mode 100644
index 0000000000000..0ddf2ec192a3e
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -0,0 +1,153 @@
+//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a transform to simplify mix/max affine operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/IntEqClasses.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "affine-min-max"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::affine;
+
+/// Simplifies an affine min/max operation by proving there's a lower or upper
+/// bound.
+template <typename AffineOp>
+static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
+  using Variable = ValueBoundsConstraintSet::Variable;
+  using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
+
+  AffineMap affineMap = affineOp.getMap();
+  ValueRange operands = affineOp.getOperands();
+  static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
+
+  LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+
+  // Create a `Variable` list with values corresponding to each of the results
+  // in the affine affineMap.
+  SmallVector<Variable> variables = llvm::map_to_vector(
+      llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
+      [&](unsigned i) {
+        return Variable(affineMap.getSliceMap(i, 1), operands);
+      });
+
+  // Get the comparison operation.
+  ComparisonOperator cmpOp =
+      isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
+
+  // Find disjoint sets bounded by a common value.
+  llvm::IntEqClasses boundedClasses(variables.size());
+  DenseMap<unsigned, Variable *> bounds;
+  for (auto &&[i, v] : llvm::enumerate(variables)) {
+    unsigned eqClass = boundedClasses.findLeader(i);
+
+    // If the class already has a bound continue.
+    if (bounds.contains(eqClass))
+      continue;
+
+    // Initialize the bound.
+    Variable *bound = &v;
+
+    LLVM_DEBUG({
+      DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+             << "`\n";
+    });
+
+    // Check against the other variables.
+    for (size_t j = i + 1; j < variables.size(); ++j) {
+      unsigned jEqClass = boundedClasses.findLeader(j);
+      // Get the bound of the equivalence class or itself.
+      Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
+
+      LLVM_DEBUG({
+        DBGS() << "- comparing with variable: #" << jEqClass
+               << ", with map: " << nv->getMap() << "\n";
+      });
+
+      // Compare the variables.
+      std::optional<bool> cmpResult =
+          ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
+
+      // The variables cannot be compared.
+      if (!cmpResult) {
+        LLVM_DEBUG({
+          DBGS() << "-- classes: #" << i << ", #" << jEqClass
+                 << " cannot be merged\n";
+        });
+        continue;
+      }
+
+      // Join the equivalent classes and update the bound if necessary.
+      LLVM_DEBUG({
+        DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
+               << ", is lhs <= rhs: " << *cmpResult << "`\n";
+      });
+      if (*cmpResult) {
+        boundedClasses.join(eqClass, jEqClass);
+      } else {
+        // In this case we have lhs > rhs if isMin == true, or lhs < rhs if
+        // isMin == false.
+        bound = nv;
+        boundedClasses.join(eqClass, jEqClass);
+      }
+    }
+    bounds[boundedClasses.findLeader(i)] = bound;
+  }
+
+  // Return if there's no simplification.
+  if (bounds.size() >= affineMap.getNumResults()) {
+    LLVM_DEBUG(
+        { DBGS() << "- the affine operation couldn't get simplified\n"; });
+    return false;
+  }
+
+  // Construct the new affine affineMap.
+  SmallVector<AffineExpr> results;
+  results.reserve(bounds.size());
+  for (auto [k, bound] : bounds)
+    results.push_back(bound->getMap().getResult(0));
+
+  affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
+                             results, rewriter.getContext());
+
+  // Update the affine op.
+  rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
+  LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
+  return true;
+}
+
+bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
+  return simplifyAffineMinMaxOp(rewriter, op);
+}
+
+bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
+  return simplifyAffineMinMaxOp(rewriter, op);
+}
+
+void mlir::affine::simplifyAffineMinMaxOps(
+    RewriterBase &rewriter, Operation *topOp,
+    SmallVectorImpl<Operation *> &modifiedOps) {
+  assert(topOp && "null-op");
+  topOp->walk([&](Operation *op) {
+    if (auto affineOp = dyn_cast<AffineMinOp>(op)) {
+      if (simplifyAffineMinMaxOp(rewriter, affineOp))
+        modifiedOps.push_back(op);
+    } else if (auto affineOp = dyn_cast<AffineMaxOp>(op)) {
+      if (simplifyAffineMinMaxOp(rewriter, affineOp))
+        modifiedOps.push_back(op);
+    }
+  });
+}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 87f883c2e6485..d7a6187cafb1e 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -146,7 +146,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
 }
 
 ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
-                                             ArrayRef<Value> mapOperands)
+                                             ValueRange mapOperands)
     : Variable(map, llvm::map_to_vector(mapOperands,
                                         [](Value v) { return Variable(v); })) {}
 
@@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
   return isEmpty;
 }
 
+std::optional<bool> ValueBoundsConstraintSet::strongComparePos(
+    int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
+  auto strongCmp = [&](ComparisonOperator cmp,
+                       ComparisonOperator negCmp) -> std::optional<bool> {
+    if (comparePos(lhsPos, cmp, rhsPos))
+      return true;
+    if (comparePos(lhsPos, negCmp, rhsPos))
+      return false;
+    return std::nullopt;
+  };
+  switch (cmp) {
+  case ComparisonOperator::LT:
+    return strongCmp(ComparisonOperator::LT, ComparisonOperator::GE);
+  case ComparisonOperator::LE:
+    return strongCmp(ComparisonOperator::LE, ComparisonOperator::GT);
+  case ComparisonOperator::GT:
+    return strongCmp(ComparisonOperator::GT, ComparisonOperator::LE);
+  case ComparisonOperator::GE:
+    return strongCmp(ComparisonOperator::GE, ComparisonOperator::LT);
+  case ComparisonOperator::EQ: {
+    std::optional<bool> le =
+        strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos);
+    if (!le)
+      return std::nullopt;
+    if (!*le)
+      return false;
+    std::optional<bool> ge =
+        strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos);
+    if (!ge)
+      return std::nullopt;
+    if (!*ge)
+      return false;
+    return true;
+  }
+  }
+  llvm_unreachable("invalid comparison operator");
+}
+
 bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
                                                   ComparisonOperator cmp,
                                                   const Variable &rhs) {
@@ -763,6 +801,25 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
   return cstr.comparePos(lhsPos, cmp, rhsPos);
 }
 
+std::optional<bool> ValueBoundsConstraintSet::strongCompare(
+    const Variable &lhs, ComparisonOperator cmp, const Variable &rhs) {
+  int64_t lhsPos = -1, rhsPos = -1;
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep processing as long as lhs/rhs were not processed.
+    if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
+        size_t(rhsPos) >= cstr.positionToValueDim.size())
+      return false;
+    // Keep processing as long as the strong relation cannot be proven.
+    std::optional<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
+    return ordered ? true : false;
+  };
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+  rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
+  return cstr.strongComparePos(lhsPos, cmp, rhsPos);
+}
+
 FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
                                                    const Variable &var2) {
   if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
new file mode 100644
index 0000000000000..2b6de62073e99
--- /dev/null
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt  %s  --transform-interpreter | FileCheck %s
+
+// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
+// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
+// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
+
+// CHECK: @min_max_full_simplify
+func.func @min_max_full_simplify() -> (index, index) {
+  %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK-NOT: affine.min
+  // CHECK-NOT: affine.max
+  // CHECK: return %[[V0]], %[[V1]]
+  %r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+  %r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+  return %r0, %r1 : index, index
+}
+
+// CHECK: @min_only_simplify
+func.func @min_only_simplify() -> (index, index) {
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
+  // CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
+  %0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  %r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+  %r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+  return %r0, %r1 : index, index
+}
+
+// CHECK: @max_only_simplify
+func.func @max_only_simplify() -> (index, index) {
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+  // CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
+  %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  %r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+  %r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+  return %r0, %r1 : index, index
+}
+
+// CHECK: @overlapping_constraints
+func.func @ov...
[truncated]

@fabianmcg fabianmcg requested a review from qedawkins June 20, 2025 17:35
Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for the ValueBoundsConstraintSet part. I did not review the other changes in-depth.

I saw that there is SimplifyAffineMinMaxOp and SimplifyBoundedAffineOpsOp. Can we combine those two with SimplifyAffineMinMax.cpp?

@fabianmcg
Copy link
Contributor Author

fabianmcg commented Jun 20, 2025

I saw that there is SimplifyAffineMinMaxOp and SimplifyBoundedAffineOpsOp. Can we combine those two with SimplifyAffineMinMax.cpp?

I don't know if SimplifyBoundedAffineOpsOp has any users, so it didn't even occurred to me. But regardless of that it cannot be removed until we have something like:

%iA = arith.assume %i constraints min = 0, max = 128 : index

As SimplifyBoundedAffineOpsOp relies on users adding constant bounds to specific values in the transform.
In the future we could remove it.

@matthias-springer
Copy link
Member

Is IREE using SimplifyBoundedAffineOpsOp? If not, we can probably delete it. (I added this infrastructure a while ago.)

@fabianmcg
Copy link
Contributor Author

fabianmcg commented Jun 20, 2025

Is IREE using SimplifyBoundedAffineOpsOp?

I just checked and no. But I would wait for removal until we have the assume op upstream, otherwise we could be removing functionalities without alternatives.

/// This function is similar to `ValueBoundsConstraintSet::compare`, except
/// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the
/// relation nor its inverse relation could be proven.
static llvm::FailureOr<bool> strongCompare(const Variable &lhs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't the proper mathematical term "strict" comparison?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, no. Here, strong doesn't refer to the comparison operator, but rather the method guarantees compared to the ValueBoundsConstraintSet::compare method. To be specific:

compare(x, cmp, y) == true   <==>   strongCompare(x, cmp, y) == true
&&
strongCompare(x, cmp, y) == false   ==>  compare(x, cmp, y) == false
// But 
compare(x, cmp, y) == false   =/=>   strongCompare(x, cmp, y) == false

That's because compare(x, cmp, y) == false can mean the inequality couldn't be proven (not necessarily that it's false), and in those cases strongCompare returns failure.

///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
llvm::FailureOr<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strict?

@@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}

FailureOr<bool> ValueBoundsConstraintSet::strongComparePos(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strict ?

@@ -763,14 +801,29 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
return cstr.comparePos(lhsPos, cmp, rhsPos);
}

FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strict?


// Compare the variables.
FailureOr<bool> cmpResult =
ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate why you need a strict comparison here?

It is not clear to me: redundant constraints can be dropped with regular comparisons so if you had to do this I am wondering if there is something off.

In particular we need: forall x f(x) <= g(x) (and not e.g. exists x such that f(x) < g(x) but I don't see evidence that this is what is being computed).

Copy link
Contributor Author

@fabianmcg fabianmcg Jun 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because, I don't know whether you're are objecting using strongCompare or LT, GT as comparators, I'll provide 2 answers.

  1. The reason for using strongCompare is: comparison with ValueBoundsConstraintSet is a weak order. Therefore, we need to know if they do or do not compare, otherwise, we might find or miss a min when we shouldn't or should.

  2. I can switch it to LE, GE, but it shouldn't make a difference. Because if x < y, then I keep the bound in x, but if x >= y I switch the bound to y, switching to x <= y will update x >= y to x > y, so in theory the same expressions.

Copy link
Contributor

@nicolasvasilache nicolasvasilache left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great stuff @fabianmcg !

I was a little confused because I didn't see how the new ops pass their information to the analysis but then I realized it's populateBoundsForIndexValue thanks to the good design of the interface (thanks @matthias-springer :) ).

Looks good to me modulo my question regarding the need for strict comparison. If there is a simple answer I missed feel free to just land after addressing nits.

transform.apply_patterns.canonicalization
} {apply_cse} : !transform.any_op
transform.affine.simplify_min_max_affine_ops %2 : !transform.any_op
%3 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm why do you need to rematch here?
You shouldn't need to (see apply_registered_pass and apply_patterns that don't lose the handle %2, you shouldn't either).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, simplify_min_max_affine_ops always consumes and invalidates the handle. I'll update the transform op.

nicolasvasilache added a commit to nicolasvasilache/iree that referenced this pull request Jun 21, 2025
This PR makes OnlineAttention derive from IndexingMapOpInterface and make it pad with transform.structured.pad_tiling_interface.
Additionally, ensures the dynamic case pads to a constant before tiling and properly canonicalizes to constant shapes
once AffineMin simplification kicks in.

This requires integrating LLVM past llvm/llvm-project#145068 once it has landed.
nicolasvasilache added a commit to nicolasvasilache/iree that referenced this pull request Jun 21, 2025
This PR makes OnlineAttention derive from IndexingMapOpInterface and make it pad with transform.structured.pad_tiling_interface.
Additionally, ensures the dynamic case pads to a constant before tiling and properly canonicalizes to constant shapes
once AffineMin simplification kicks in.

This requires integrating LLVM past llvm/llvm-project#145068 once it has landed.
nicolasvasilache added a commit to nicolasvasilache/iree that referenced this pull request Jun 21, 2025
This PR makes OnlineAttention derive from IndexingMapOpInterface and make it pad with transform.structured.pad_tiling_interface.
Additionally, ensures the dynamic case pads to a constant before tiling and properly canonicalizes to constant shapes
once AffineMin simplification kicks in.

This requires integrating LLVM past llvm/llvm-project#145068 once it has landed.
nicolasvasilache added a commit that referenced this pull request Jun 21, 2025
…ransform.pad-tiling-interface

This revision introduces a simple variant of AffineMin folding in makeComposedFoldedAffineApply
and makes use of it in transform.pad-tiling-interface.
Since this version explicitly call ValueBoundsInterface, it may be too expensive and is
only activate behind a flag.
It results in better foldings when mixing tiling and padding, including with dynamic shapes.

This should be further composed with #145068 to provide full simplification and address
the remaining TODO in the test.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants