[mlir][Vector] Remove vector.extractelement/insertelement from sparse vectorizer #143270
Conversation
[mlir][Vector] Remove vector.extractelement/insertelement from sparse vectorizer: This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops. It updates the Sparse Vectorizer to use `vector.extract` and `vector.insert` instead of `vector.extractelement` and `vector.insertelement`.
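For reference, the op syntax change as it appears in the updated test expectations below (value names are illustrative):

// Before: vector.insertelement takes a dynamic index operand.
%r = vector.insertelement %l, %v0[%c0 : index] : vector<16xf32>
// After: vector.insert uses a static position and spells out the element type.
%r = vector.insert %l, %v0 [0] : f32 into vector<16xf32>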
@llvm/pr-subscribers-mlir
Author: Diego Caballero (dcaballe)
Changes: This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops. It updates the Sparse Vectorizer to use `vector.extract` and `vector.insert` instead.
Full diff: https://github.com/llvm/llvm-project/pull/143270.diff — 5 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 3d963dea2f572..482720ba8aed1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -198,12 +198,12 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
case vector::CombiningKind::ADD:
case vector::CombiningKind::XOR:
// Initialize reduction vector to: | 0 | .. | 0 | r |
- return rewriter.create<vector::InsertElementOp>(
+ return rewriter.create<vector::InsertOp>(
loc, r, constantZero(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
case vector::CombiningKind::MUL:
// Initialize reduction vector to: | 1 | .. | 1 | r |
- return rewriter.create<vector::InsertElementOp>(
+ return rewriter.create<vector::InsertOp>(
loc, r, constantOne(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
case vector::CombiningKind::AND:
@@ -628,31 +628,48 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
const VL vl;
};
+static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
+ Value inp) {
+ if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
+ if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+ if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
+ rewriter.replaceOp(op, redOp.getVector());
+ return success();
+ }
+ }
+ }
+ return failure();
+}
+
/// Reduction chain cleanup.
/// v = for { }
-/// s = vsum(v) v = for { }
-/// u = expand(s) -> for (v) { }
+/// s = vsum(v) v = for { }
+/// u = broadcast(s) -> for (v) { }
/// for (u) { }
-template <typename VectorOp>
-struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+struct ReducChainBroadcastRewriter : public OpRewritePattern<vector::BroadcastOp> {
public:
- using OpRewritePattern<VectorOp>::OpRewritePattern;
+ using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(VectorOp op,
+ LogicalResult matchAndRewrite(vector::BroadcastOp op,
PatternRewriter &rewriter) const override {
- Value inp = op.getSource();
- if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
- if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
- if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
- rewriter.replaceOp(op, redOp.getVector());
- return success();
- }
- }
- }
- return failure();
+ return cleanReducChain(rewriter, op, op.getSource());
}
};
+/// Reduction chain cleanup.
+/// v = for { }
+/// s = vsum(v) v = for { }
+/// u = insert(s) -> for (v) { }
+/// for (u) { }
+struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
+public:
+ using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::InsertOp op,
+ PatternRewriter &rewriter) const override {
+ return cleanReducChain(rewriter, op, op.getValueToStore());
+ }
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -668,6 +685,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
vector::populateVectorStepLoweringPatterns(patterns);
patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
enableVLAVectorization, enableSIMDIndex32);
- patterns.add<ReducChainRewriter<vector::InsertElementOp>,
- ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
+ patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
index 2475aa5139da4..b2dfbeb53fde8 100755
--- a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
@@ -22,7 +22,7 @@
// CHECK-NOVEC: }
//
// CHECK-VEC-LABEL: func.func @sum_reduction
-// CHECK-VEC: vector.insertelement
+// CHECK-VEC: vector.insert
// CHECK-VEC: scf.for
// CHECK-VEC: vector.create_mask
// CHECK-VEC: vector.maskedload
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 364ba6e71ff3b..64235c7227800 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -241,7 +241,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
// CHECK-VEC16-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC16: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
// CHECK-VEC16: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC16: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -258,7 +258,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
// CHECK-VEC16-IDX32-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16-IDX32-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC16-IDX32: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
// CHECK-VEC16-IDX32: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC16-IDX32: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16-IDX32: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -278,7 +278,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
// CHECK-VEC4-SVE: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC4-SVE: %[[vscale:.*]] = vector.vscale
// CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
-// CHECK-VEC4-SVE: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
+// CHECK-VEC4-SVE: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index f4b565c7f9c8a..0ab72897d7bc3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -82,7 +82,7 @@
// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
// CHECK: scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
// CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK: %[[VAL_59:.*]] = vector.insert %[[VAL_60:.*]]#2, %[[VAL_4]] [0] : f64 into vector<8xf64>
// CHECK: %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
// CHECK: %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
// CHECK: %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
index 01b717090e87a..6effbbf98abb7 100644
--- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
@@ -172,7 +172,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_3]] : index] : vector<8xi32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_4]] [0] : i32 into vector<8xi32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -247,7 +247,7 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -323,7 +323,7 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -399,7 +399,7 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -475,7 +475,7 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@llvm/pr-subscribers-mlir-sparse
Author: Diego Caballero (dcaballe)
Changes: same description and full diff as in the @llvm/pr-subscribers-mlir comment above.
✅ With the latest revision this PR passed the C/C++ code formatter.
[mlir][Vector] Remove vector.extractelement/insertelement from sparse vectorizer (llvm#143270): This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops. RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops. It updates the Sparse Vectorizer to use `vector.extract` and `vector.insert` instead of `vector.extractelement` and `vector.insertelement`.
This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops (RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops). It updates the Sparse Vectorizer to use `vector.extract` and `vector.insert` instead of `vector.extractelement` and `vector.insertelement`.
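As context for the rewriter change, the reduction chain cleanup performed by cleanReducChain (via ReducChainInsertRewriter and ReducChainBroadcastRewriter in the diff above) folds the scalar round-trip between two sparse-emitted loops. A minimal sketch of the IR shape it matches, with illustrative value names and elided loop bodies:

// Before cleanup: the first loop's vector result is reduced to a scalar and
// immediately re-inserted (or broadcast) to seed the next loop.
%v = scf.for ... -> (vector<8xf64>) { ... }   // loop tagged by the sparse LoopEmitter
%s = vector.reduction <add>, %v : vector<8xf64> into f64
%u = vector.insert %s, %cst [0] : f64 into vector<8xf64>
%w = scf.for ... iter_args(%acc = %u) -> (vector<8xf64>) { ... }

// After cleanup: the insert (and the reduction, once dead) go away and the
// second loop consumes %v directly.
%v = scf.for ... -> (vector<8xf64>) { ... }
%w = scf.for ... iter_args(%acc = %v) -> (vector<8xf64>) { ... }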