
[mlir][Vector] Remove vector.extractelement/insertelement from sparse vectorizer #143270


Merged

Conversation

@dcaballe (Contributor) commented Jun 7, 2025

This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops (RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops). It updates the sparse vectorizer to use `vector.extract` and `vector.insert` instead of `vector.extractelement` and `vector.insertelement`.
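
For illustration, a minimal MLIR sketch of the migration (the function and value names here are invented, not taken from the PR):

```mlir
// Hypothetical example: the deprecated element ops vs. their replacements.
func.func @migrate(%v: vector<8xf32>, %s: f32) -> (f32, vector<8xf32>) {
  // Deprecated forms, which take a dynamic index even for constant positions:
  //   %e = vector.extractelement %v[%c0 : index] : vector<8xf32>
  //   %r = vector.insertelement %s, %v[%c0 : index] : vector<8xf32>
  // Replacement forms with a static position:
  %e = vector.extract %v[0] : f32 from vector<8xf32>
  %r = vector.insert %s, %v [0] : f32 into vector<8xf32>
  return %e, %r : f32, vector<8xf32>
}
```

The insert sites touched by this diff all write lane 0, so the static position `[0]` expresses the same thing the old `%c0 : index` operand did.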

@llvmbot (Member) commented Jun 7, 2025

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops. It updates the sparse vectorizer to use `vector.extract` and `vector.insert` instead of `vector.extractelement` and `vector.insertelement`.


Full diff: https://github.com/llvm/llvm-project/pull/143270.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+37-20)
  • (modified) mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+5-5)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 3d963dea2f572..482720ba8aed1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -198,12 +198,12 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::XOR:
     // Initialize reduction vector to: | 0 | .. | 0 | r |
-    return rewriter.create<vector::InsertElementOp>(
+    return rewriter.create<vector::InsertOp>(
         loc, r, constantZero(rewriter, loc, vtp),
         constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::MUL:
     // Initialize reduction vector to: | 1 | .. | 1 | r |
-    return rewriter.create<vector::InsertElementOp>(
+    return rewriter.create<vector::InsertOp>(
         loc, r, constantOne(rewriter, loc, vtp),
         constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::AND:
@@ -628,31 +628,48 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
   const VL vl;
 };
 
+static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
+                                     Value inp) {
+  if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
+    if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+      if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
+        rewriter.replaceOp(op, redOp.getVector());
+        return success();
+      }
+    }
+  }
+  return failure();
+}
+
 /// Reduction chain cleanup.
 ///   v = for { }
-///   s = vsum(v)               v = for { }
-///   u = expand(s)       ->    for (v) { }
+///   s = vsum(v)                  v = for { }
+///   u = broadcast(s)       ->    for (v) { }
 ///   for (u) { }
-template <typename VectorOp>
-struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+struct ReducChainBroadcastRewriter : public OpRewritePattern<vector::BroadcastOp> {
 public:
-  using OpRewritePattern<VectorOp>::OpRewritePattern;
+  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(VectorOp op,
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
-    Value inp = op.getSource();
-    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
-      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
-        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
-          rewriter.replaceOp(op, redOp.getVector());
-          return success();
-        }
-      }
-    }
-    return failure();
+    return cleanReducChain(rewriter, op, op.getSource());
   }
 };
 
+/// Reduction chain cleanup.
+///   v = for { }
+///   s = vsum(v)               v = for { }
+///   u = insert(s)       ->    for (v) { }
+///   for (u) { }
+struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
+public:
+  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    return cleanReducChain(rewriter, op, op.getValueToStore());
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -668,6 +685,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
   vector::populateVectorStepLoweringPatterns(patterns);
   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                               enableVLAVectorization, enableSIMDIndex32);
-  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
-               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
+  patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
index 2475aa5139da4..b2dfbeb53fde8 100755
--- a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
@@ -22,7 +22,7 @@
 // CHECK-NOVEC:       }
 //
 // CHECK-VEC-LABEL: func.func @sum_reduction
-// CHECK-VEC:       vector.insertelement
+// CHECK-VEC:       vector.insert
 // CHECK-VEC:       scf.for
 // CHECK-VEC:         vector.create_mask
 // CHECK-VEC:         vector.maskedload
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 364ba6e71ff3b..64235c7227800 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -241,7 +241,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -258,7 +258,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-IDX32-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-IDX32-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16-IDX32:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16-IDX32:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16-IDX32:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16-IDX32:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16-IDX32:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16-IDX32:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -278,7 +278,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC4-SVE:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
 // CHECK-VEC4-SVE:       %[[vscale:.*]] = vector.vscale
 // CHECK-VEC4-SVE:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
-// CHECK-VEC4-SVE:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
+// CHECK-VEC4-SVE:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<[4]xf32>
 // CHECK-VEC4-SVE:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
 // CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
 // CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index f4b565c7f9c8a..0ab72897d7bc3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -82,7 +82,7 @@
 // CHECK:               %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
 // CHECK:               scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
 // CHECK:             } attributes {"Emitted from" = "linalg.generic"}
-// CHECK:             %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK:             %[[VAL_59:.*]] = vector.insert %[[VAL_60:.*]]#2, %[[VAL_4]] [0] : f64 into vector<8xf64>
 // CHECK:             %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
 // CHECK:               %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
 // CHECK:               %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
index 01b717090e87a..6effbbf98abb7 100644
--- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
@@ -172,7 +172,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
 // CHECK-ON:           %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:           %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK-ON:           %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:           %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_3]] : index] : vector<8xi32>
+// CHECK-ON:           %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_4]] [0] : i32 into vector<8xi32>
 // CHECK-ON:           %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:             %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:             %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -247,7 +247,7 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
 // CHECK-ON:  %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:  %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:  %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:  %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON:  %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON:  %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:    %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:    %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -323,7 +323,7 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -399,7 +399,7 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -475,7 +475,7 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
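
To make the reduction-chain cleanup concrete, the following self-contained sketch shows the IR shape that the new `ReducChainInsertRewriter` targets (bounds, bodies, and names are invented for illustration; only the op chain and the loop attribute matter):

```mlir
func.func @reduc_chain(%buf: memref<1024xf32>, %r: f32) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %c1024 = arith.constant 1024 : index
  %pad = arith.constant dense<0.000000e+00> : vector<8xf32>
  %init = vector.insert %r, %pad [0] : f32 into vector<8xf32>
  // v = for { }: a vectorized loop tagged by the sparse loop emitter.
  %v = scf.for %i = %c0 to %c1024 step %c8
      iter_args(%acc = %init) -> (vector<8xf32>) {
    %ld = vector.load %buf[%i] : memref<1024xf32>, vector<8xf32>
    %add = arith.addf %acc, %ld : vector<8xf32>
    scf.yield %add : vector<8xf32>
  } attributes {"Emitted from" = "linalg.generic"}
  // s = vsum(v); u = insert(s): scalarize the partial sums, then
  // re-materialize them as lane 0 of a fresh vector.
  %s = vector.reduction <add>, %v : vector<8xf32> into f32
  %u = vector.insert %s, %pad [0] : f32 into vector<8xf32>
  // for (u) { }: the next loop in the chain consumes the round-tripped value.
  %w = scf.for %j = %c0 to %c1024 step %c8
      iter_args(%acc2 = %u) -> (vector<8xf32>) {
    %ld2 = vector.load %buf[%j] : memref<1024xf32>, vector<8xf32>
    %add2 = arith.addf %acc2, %ld2 : vector<8xf32>
    scf.yield %add2 : vector<8xf32>
  }
  return %w : vector<8xf32>
}
```

The pattern matches `%u`, walks back through the `vector.reduction` to the `scf.for` tagged with `LoopEmitter::getLoopEmitterLoopAttrName()`, and replaces `%u` with `%v`, so the second loop is seeded with the lanewise partial sums instead of the scalarize/re-vectorize round trip. This is sound in these chains because the value is eventually reduced again, and `%v` reduces to the same scalar as `%u`. `ReducChainBroadcastRewriter` handles the analogous chain where the scalar comes back via `vector.broadcast`.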

@llvmbot (Member) commented Jun 7, 2025

@llvm/pr-subscribers-mlir-sparse

Author: Diego Caballero (dcaballe)

(Same changes summary and full diff as in the @llvm/pr-subscribers-mlir notification above.)
github-actions bot commented Jun 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@dcaballe merged commit 1ac61c8 into llvm:main on Jun 12, 2025
7 checks passed
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
[mlir][Vector] Remove vector.extractelement/insertelement from sparse vectorizer (llvm#143270)

This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

It updates the Sparse Vectorizer to use `vector.extract` and `vector.insert` instead of `vector.extractelement` and `vector.insertelement`.
Labels: mlir, mlir:sparse (Sparse compiler in MLIR)