-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][vector] Remove redundant shape_cast(shape_cast(x)) pattern #135447
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
OpRewritePattern
@llvm/pr-subscribers-mlir Author: James Newling (newling) ChangesThis PR removes one OpRewritePattern Note that this might affect downstream users who indirectly call Full diff: https://github.com/llvm/llvm-project/pull/135447.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..ce97847172197 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -306,10 +306,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options,
PatternBenefit benefit = 1);
-/// Collect a set of vector.shape_cast folding patterns.
-void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index fda3baf3aa390..68a44ea889470 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -8,7 +8,6 @@
#include <numeric>
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -577,5 +576,4 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
- populateShapeCastFoldingPatterns(patterns, benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 62dfd439b0ad1..999fb9c415886 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -976,7 +976,6 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
patterns
.add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
patterns.getContext(), benefit);
- populateShapeCastFoldingPatterns(patterns);
}
void mlir::vector::populateFlattenVectorTransferPatterns(
@@ -985,6 +984,5 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
patterns.add<FlattenContiguousRowMajorTransferReadPattern,
FlattenContiguousRowMajorTransferWritePattern>(
patterns.getContext(), targetVectorBitwidth, benefit);
- populateShapeCastFoldingPatterns(patterns, benefit);
populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d50d5fe96f49a..89839d0440d3c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -16,36 +16,24 @@
#include <cstdint>
#include <functional>
#include <optional>
-#include <type_traits>
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "vector-to-vector"
@@ -71,54 +59,6 @@ static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
namespace {
-/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
-//
-// Example:
-//
-// The following MLIR with cancelling ShapeCastOps:
-//
-// %0 = source : vector<5x4x2xf32>
-// %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
-// %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
-// %3 = user %2 : vector<5x4x2xf32>
-//
-// Should canonicalize to the following:
-//
-// %0 = source : vector<5x4x2xf32>
-// %1 = user %0 : vector<5x4x2xf32>
-//
-struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- // Check if 'shapeCastOp' has vector source/result type.
- auto sourceVectorType =
- dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
- auto resultVectorType =
- dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
- if (!sourceVectorType || !resultVectorType)
- return failure();
-
- // Check if shape cast op source operand is also a shape cast op.
- auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
- shapeCastOp.getSource().getDefiningOp());
- if (!sourceShapeCastOp)
- return failure();
- auto operandSourceVectorType =
- cast<VectorType>(sourceShapeCastOp.getSource().getType());
- auto operandResultVectorType = sourceShapeCastOp.getType();
-
- // Check if shape cast operations invert each other.
- if (operandSourceVectorType != resultVectorType ||
- operandResultVectorType != sourceVectorType)
- return failure();
-
- rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
- return success();
- }
-};
-
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
@@ -2113,11 +2053,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
patterns.add<FoldI1Select>(patterns.getContext(), benefit);
}
-void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit) {
- patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
-}
-
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
// TODO: Consider either:
@@ -2126,8 +2061,7 @@ void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
// * better naming to distinguish this and
// populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
- DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
- patterns.getContext(), benefit);
+ DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
|
@llvm/pr-subscribers-mlir-vector Author: James Newling (newling) ChangesThis PR removes one OpRewritePattern Note that this might affect downstream users who indirectly call Full diff: https://github.com/llvm/llvm-project/pull/135447.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..ce97847172197 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -306,10 +306,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options,
PatternBenefit benefit = 1);
-/// Collect a set of vector.shape_cast folding patterns.
-void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index fda3baf3aa390..68a44ea889470 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -8,7 +8,6 @@
#include <numeric>
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -577,5 +576,4 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
- populateShapeCastFoldingPatterns(patterns, benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 62dfd439b0ad1..999fb9c415886 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -976,7 +976,6 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
patterns
.add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
patterns.getContext(), benefit);
- populateShapeCastFoldingPatterns(patterns);
}
void mlir::vector::populateFlattenVectorTransferPatterns(
@@ -985,6 +984,5 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
patterns.add<FlattenContiguousRowMajorTransferReadPattern,
FlattenContiguousRowMajorTransferWritePattern>(
patterns.getContext(), targetVectorBitwidth, benefit);
- populateShapeCastFoldingPatterns(patterns, benefit);
populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d50d5fe96f49a..89839d0440d3c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -16,36 +16,24 @@
#include <cstdint>
#include <functional>
#include <optional>
-#include <type_traits>
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "vector-to-vector"
@@ -71,54 +59,6 @@ static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
namespace {
-/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
-//
-// Example:
-//
-// The following MLIR with cancelling ShapeCastOps:
-//
-// %0 = source : vector<5x4x2xf32>
-// %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
-// %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
-// %3 = user %2 : vector<5x4x2xf32>
-//
-// Should canonicalize to the following:
-//
-// %0 = source : vector<5x4x2xf32>
-// %1 = user %0 : vector<5x4x2xf32>
-//
-struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- // Check if 'shapeCastOp' has vector source/result type.
- auto sourceVectorType =
- dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
- auto resultVectorType =
- dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
- if (!sourceVectorType || !resultVectorType)
- return failure();
-
- // Check if shape cast op source operand is also a shape cast op.
- auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
- shapeCastOp.getSource().getDefiningOp());
- if (!sourceShapeCastOp)
- return failure();
- auto operandSourceVectorType =
- cast<VectorType>(sourceShapeCastOp.getSource().getType());
- auto operandResultVectorType = sourceShapeCastOp.getType();
-
- // Check if shape cast operations invert each other.
- if (operandSourceVectorType != resultVectorType ||
- operandResultVectorType != sourceVectorType)
- return failure();
-
- rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
- return success();
- }
-};
-
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
@@ -2113,11 +2053,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
patterns.add<FoldI1Select>(patterns.getContext(), benefit);
}
-void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit) {
- patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
-}
-
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
// TODO: Consider either:
@@ -2126,8 +2061,7 @@ void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
// * better naming to distinguish this and
// populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
- DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
- patterns.getContext(), benefit);
+ DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool!
…vm#135447) This PR removes one OpRewritePattern `shape_cast(shape_cast(x)) -> x` that is already handled by `ShapeCastOp::fold`. Note that this might affect downstream users who indirectly call `populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit)` and then use `patterns` with a `GreedyRewriteConfig config` that has `config.fold = false`. (only user I've checked is IREE, that never uses config.fold = false).
This PR removes one OpRewritePattern
shape_cast(shape_cast(x)) -> x
that is already handled byShapeCastOp::fold
.Note that this might affect downstream users who indirectly call
populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit)
and then usepatterns
with aGreedyRewriteConfig config
that hasconfig.fold = false
. (only user I've checked is IREE, that never uses config.fold = false).