Skip to content

[MLIR][Arith] add and(a, or(a,b)) folder #138998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

wsmoses
Copy link
Member

@wsmoses wsmoses commented May 8, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented May 8, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: William Moses (wsmoses)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/138998.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+12)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 3b308716c84dc..7cf65cdd4f2da 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -896,6 +896,18 @@ OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
   if (Value result = foldAndIofAndI(*this))
     return result;
 
+  /// and(a, or(a, b)) -> a
+   for (int i = 0; i < 2; i++) {
+     auto a = getOperand(1 - i);
+     if (auto orOp = getOperand(i).getDefiningOp<arith::OrIOp>()) {
+       for (int j = 0; j < 2; j++) {
+         if (orOp->getOperand(j) == a) {
+           return a;
+         }
+       }
+     }
+   }
+
   return constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(),
       [](APInt a, const APInt &b) { return std::move(a) & b; });

@wsmoses wsmoses force-pushed the users/wmoses/arith branch from d1d3a1a to 6775523 Compare May 8, 2025 00:35
@krzysz00
Copy link
Contributor

krzysz00 commented May 8, 2025

Question - partly for the peanut gallery - should this be a canonicalization instead of a fold?

@@ -896,6 +896,18 @@ OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
if (Value result = foldAndIofAndI(*this))
return result;

/// and(a, or(a, b)) -> a
for (int i = 0; i < 2; i++) {
auto a = getOperand(1 - i);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Spell out the full type here since it's not immediately obvious based on the RHS: https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable

@@ -896,6 +896,18 @@ OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
if (Value result = foldAndIofAndI(*this))
return result;

/// and(a, or(a, b)) -> a
for (int i = 0; i < 2; i++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

@joker-eph joker-eph Jun 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would really rather avoid raw loops entirely.

Can this just be written as:

for (Value operand : getOperands()) {
  if (auto orOp = operand..getDefiningOp<arith::OrIOp>()) {
  

(same for the second loop)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the reason that might be hard is that we need both the original operand, and also the "other" operand. In the loop form that can be written as get operand[i] and operand[1-i], and/or operand[i] and i == 0 ? operand[1] : operand[0]. Not sure how to do that as a single foreach iterator

Copy link
Collaborator

@joker-eph joker-eph Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I missed the 1-i in the code, this looks like all code obfuscation to me.

I would write this with a lambda instead, something like:

/// and(or(a, b), a) -> a
auto matchAndOr = [&] (Value lhs, Value rhs) {
  auto orOp = lhs.getDefiningOp<arith::OrIOp>();
   if (!orOp) return false;
   for (Value orOperand : orOp->getOperands())
     if (orOperand == rhs) return true;
   return false;
};

Value lhs = getOperand(0);
Value rhs = getOperand(1);
if (matchAndOr(lhs, rhs)) return rhs;

/// `and` is commutative, swap the operands: `and(a, or(a, b)) -> a`
if (matchAndOr(rhs, lhs)) return lhs;

for (int i = 0; i < 2; i++) {
auto a = getOperand(1 - i);
if (auto orOp = getOperand(i).getDefiningOp<arith::OrIOp>()) {
for (int j = 0; j < 2; j++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here

@kuhar
Copy link
Member

kuhar commented Jun 2, 2025

Question - partly for the peanut gallery - should this be a canonicalization instead of a fold?

I think we converged on a policy to prefer folds when possible: https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations

Comment on lines +904 to +906
if (orOp->getOperand(j) == a) {
return a;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (orOp->getOperand(j) == a) {
return a;
}
if (orOp->getOperand(j) == a)
return a;

// CHECK: return %[[A]]
func.func @andor(%a : i32, %b : i32) -> i32 {
%c = arith.ori %a, %b : i32
%res = arith.andi %a, %b : i32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be arith.andi %a, %c?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants