@@ -198,14 +198,14 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
198
198
case vector::CombiningKind::ADD:
199
199
case vector::CombiningKind::XOR:
200
200
// Initialize reduction vector to: | 0 | .. | 0 | r |
201
- return rewriter.create <vector::InsertElementOp>(
202
- loc, r, constantZero (rewriter, loc, vtp),
203
- constantIndex (rewriter, loc, 0 ));
201
+ return rewriter.create <vector::InsertOp>(loc, r,
202
+ constantZero (rewriter, loc, vtp),
203
+ constantIndex (rewriter, loc, 0 ));
204
204
case vector::CombiningKind::MUL:
205
205
// Initialize reduction vector to: | 1 | .. | 1 | r |
206
- return rewriter.create <vector::InsertElementOp>(
207
- loc, r, constantOne (rewriter, loc, vtp),
208
- constantIndex (rewriter, loc, 0 ));
206
+ return rewriter.create <vector::InsertOp>(loc, r,
207
+ constantOne (rewriter, loc, vtp),
208
+ constantIndex (rewriter, loc, 0 ));
209
209
case vector::CombiningKind::AND:
210
210
case vector::CombiningKind::OR:
211
211
// Initialize reduction vector to: | r | .. | r | r |
@@ -628,31 +628,49 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
628
628
const VL vl;
629
629
};
630
630
631
+ static LogicalResult cleanReducChain (PatternRewriter &rewriter, Operation *op, // Shared helper: replaces `op` with the vector operand of a reduction when `inp` terminates the expected chain.
632
+ Value inp) { // `inp` is the value whose defining op is inspected for the chain.
633
+ if (auto redOp = inp.getDefiningOp <vector::ReductionOp>()) { // Step 1: `inp` must be defined by a vector::ReductionOp.
634
+ if (auto forOp = redOp.getVector ().getDefiningOp <scf::ForOp>()) { // Step 2: the reduced vector must be defined by an scf::ForOp.
635
+ if (forOp->hasAttr (LoopEmitter::getLoopEmitterLoopAttrName ())) { // Step 3: that loop must carry the LoopEmitter loop attribute.
636
+ rewriter.replaceOp (op, redOp.getVector ()); // Use the unreduced vector directly in place of `op`.
637
+ return success ();
638
+ }
639
+ }
640
+ }
641
+ return failure (); // Any step did not match: `op` is left unchanged.
642
+ }
643
+
631
644
// / Reduction chain cleanup.
632
645
// / v = for { }
633
- // / s = vsum(v) v = for { }
634
- // / u = expand (s) -> for (v) { }
646
+ // / s = vsum(v) v = for { }
647
+ // / u = broadcast (s) -> for (v) { }
635
648
// / for (u) { }
636
- template < typename VectorOp>
637
- struct ReducChainRewriter : public OpRewritePattern <VectorOp > {
649
+ struct ReducChainBroadcastRewriter // Pattern rooted at vector::BroadcastOp; the rewrite itself lives in cleanReducChain.
650
+ : public OpRewritePattern<vector::BroadcastOp > {
638
651
public:
639
- using OpRewritePattern<VectorOp >::OpRewritePattern;
652
+ using OpRewritePattern<vector::BroadcastOp >::OpRewritePattern;
640
653
641
- LogicalResult matchAndRewrite (VectorOp op,
654
+ LogicalResult matchAndRewrite (vector::BroadcastOp op,
642
655
PatternRewriter &rewriter) const override {
643
- Value inp = op.getSource ();
644
- if (auto redOp = inp.getDefiningOp <vector::ReductionOp>()) {
645
- if (auto forOp = redOp.getVector ().getDefiningOp <scf::ForOp>()) {
646
- if (forOp->hasAttr (LoopEmitter::getLoopEmitterLoopAttrName ())) {
647
- rewriter.replaceOp (op, redOp.getVector ());
648
- return success ();
649
- }
650
- }
651
- }
652
- return failure ();
656
+ return cleanReducChain (rewriter, op, op.getSource ()); // The broadcast's source is the value checked for the reduction chain.
653
657
}
654
658
};
655
659
660
+ // / Reduction chain cleanup for chains that end in a vector::InsertOp.
661
+ // / v = for { }
662
+ // / s = vsum(v) v = for { }
663
+ // / u = insert(s) -> for (v) { }
664
+ // / for (u) { }
665
+ struct ReducChainInsertRewriter : public OpRewritePattern <vector::InsertOp> {
666
+ public:
667
+ using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
668
+
669
+ LogicalResult matchAndRewrite (vector::InsertOp op,
670
+ PatternRewriter &rewriter) const override {
671
+ return cleanReducChain (rewriter, op, op.getValueToStore ()); // The stored value is the one checked for the reduction chain.
672
+ }
673
+ };
656
674
} // namespace
657
675
658
676
// ===----------------------------------------------------------------------===//
@@ -668,6 +686,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
668
686
vector::populateVectorStepLoweringPatterns (patterns);
669
687
patterns.add <ForOpRewriter>(patterns.getContext (), vectorLength,
670
688
enableVLAVectorization, enableSIMDIndex32);
671
- patterns.add <ReducChainRewriter<vector::InsertElementOp>,
672
- ReducChainRewriter<vector::BroadcastOp>>( patterns.getContext ());
689
+ patterns.add <ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
690
+ patterns.getContext ());
673
691
}
0 commit comments