Skip to content

[mlir][Vector] Remove vector.extractelement/insertelement from sparse vectorizer #143270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,14 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
case vector::CombiningKind::ADD:
case vector::CombiningKind::XOR:
// Initialize reduction vector to: | 0 | .. | 0 | r |
return rewriter.create<vector::InsertElementOp>(
loc, r, constantZero(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
return rewriter.create<vector::InsertOp>(loc, r,
constantZero(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
case vector::CombiningKind::MUL:
// Initialize reduction vector to: | 1 | .. | 1 | r |
return rewriter.create<vector::InsertElementOp>(
loc, r, constantOne(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
return rewriter.create<vector::InsertOp>(loc, r,
constantOne(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
case vector::CombiningKind::AND:
case vector::CombiningKind::OR:
// Initialize reduction vector to: | r | .. | r | r |
Expand Down Expand Up @@ -628,31 +628,49 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
const VL vl;
};

static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
Value inp) {
if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
rewriter.replaceOp(op, redOp.getVector());
return success();
}
}
}
return failure();
}

/// Reduction chain cleanup.
/// v = for { }
/// s = vsum(v) v = for { }
/// u = expand(s) -> for (v) { }
/// s = vsum(v) v = for { }
/// u = broadcast(s) -> for (v) { }
/// for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
struct ReducChainBroadcastRewriter
: public OpRewritePattern<vector::BroadcastOp> {
public:
using OpRewritePattern<VectorOp>::OpRewritePattern;
using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

LogicalResult matchAndRewrite(VectorOp op,
LogicalResult matchAndRewrite(vector::BroadcastOp op,
PatternRewriter &rewriter) const override {
Value inp = op.getSource();
if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
rewriter.replaceOp(op, redOp.getVector());
return success();
}
}
}
return failure();
return cleanReducChain(rewriter, op, op.getSource());
}
};

/// Reduction chain cleanup.
/// v = for { }
/// s = vsum(v) v = for { }
/// u = insert(s) -> for (v) { }
/// for (u) { }
struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
public:
using OpRewritePattern<vector::InsertOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::InsertOp op,
PatternRewriter &rewriter) const override {
return cleanReducChain(rewriter, op, op.getValueToStore());
}
};
} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -668,6 +686,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
vector::populateVectorStepLoweringPatterns(patterns);
patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
enableVLAVectorization, enableSIMDIndex32);
patterns.add<ReducChainRewriter<vector::InsertElementOp>,
ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
patterns.getContext());
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
// CHECK-NOVEC: }
//
// CHECK-VEC-LABEL: func.func @sum_reduction
// CHECK-VEC: vector.insertelement
// CHECK-VEC: vector.insert
// CHECK-VEC: scf.for
// CHECK-VEC: vector.create_mask
// CHECK-VEC: vector.maskedload
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/SparseTensor/sparse_vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
// CHECK-VEC16-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC16: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC16: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
// CHECK-VEC16: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
// CHECK-VEC16: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC16: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
Expand All @@ -258,7 +258,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
// CHECK-VEC16-IDX32-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16-IDX32-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC16-IDX32: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
// CHECK-VEC16-IDX32: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC16-IDX32: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16-IDX32: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
Expand All @@ -278,7 +278,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
// CHECK-VEC4-SVE: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC4-SVE: %[[vscale:.*]] = vector.vscale
// CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4-SVE: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
// CHECK-VEC4-SVE: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
// CHECK: scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
// CHECK: } attributes {"Emitted from" = "linalg.generic"}
// CHECK: %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
// CHECK: %[[VAL_59:.*]] = vector.insert %[[VAL_60:.*]]#2, %[[VAL_4]] [0] : f64 into vector<8xf64>
// CHECK: %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
// CHECK: %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
// CHECK: %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_3]] : index] : vector<8xi32>
// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_4]] [0] : i32 into vector<8xi32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
Expand Down Expand Up @@ -247,7 +247,7 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
Expand Down Expand Up @@ -323,7 +323,7 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
Expand Down Expand Up @@ -399,7 +399,7 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
Expand Down Expand Up @@ -475,7 +475,7 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
Expand Down
Loading