@@ -426,6 +426,107 @@ class ConvertAtenReplicationPad2dOp
426
426
};
427
427
} // namespace
428
428
429
+ namespace {
430
+
431
+ // Lower aten.replication_pad3d operator into a sequence of
432
+ // tensor.extract_slice and tensor.concat operations.
433
+ class ConvertAtenReplicationPad3dOp
434
+ : public OpConversionPattern<AtenReplicationPad3dOp> {
435
+
436
+ private:
437
+ enum sliceLoc { START = 0 , END = 1 };
438
+
439
+ Value extractSlice (ConversionPatternRewriter &rewriter, Location loc,
440
+ Value input, int64_t dimension, sliceLoc sliceLoc) const {
441
+ auto inputType = llvm::cast<RankedTensorType>(input.getType ());
442
+ int64_t inputRank = inputType.getRank ();
443
+ SmallVector<Value> inputShape = getTensorSizes (rewriter, loc, input);
444
+
445
+ SmallVector<OpFoldResult> offsets (inputRank, rewriter.getIndexAttr (0 ));
446
+ if (sliceLoc == END) {
447
+ Value dimSize = inputShape[dimension];
448
+ Value one = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
449
+ Value endIdx = rewriter.create <arith::SubIOp>(loc, dimSize, one);
450
+ offsets[dimension] = getAsOpFoldResult (endIdx);
451
+ }
452
+
453
+ SmallVector<OpFoldResult> allOneStrides (inputRank,
454
+ rewriter.getIndexAttr (1 ));
455
+ SmallVector<OpFoldResult> sizes (inputRank, rewriter.getIndexAttr (0 ));
456
+ for (int i = 0 ; i < inputRank; ++i)
457
+ sizes[i] = (i == dimension) ? rewriter.getIndexAttr (1 )
458
+ : getAsOpFoldResult (inputShape[i]);
459
+
460
+ Value extractedSlice = rewriter.create <tensor::ExtractSliceOp>(
461
+ loc, input, offsets, sizes, allOneStrides);
462
+ return extractedSlice;
463
+ }
464
+
465
+ Value createTile (ConversionPatternRewriter &rewriter, Location loc,
466
+ Value slice, int64_t tileWidth, int64_t dimension) const {
467
+ SmallVector<Value> slices (tileWidth, slice);
468
+ if (tileWidth == 1 )
469
+ return slice;
470
+ return rewriter.create <tensor::ConcatOp>(loc, dimension, slices);
471
+ }
472
+
473
+ public:
474
+ using OpConversionPattern::OpConversionPattern;
475
+
476
+ LogicalResult
477
+ matchAndRewrite (AtenReplicationPad3dOp op, OpAdaptor adaptor,
478
+ ConversionPatternRewriter &rewriter) const override {
479
+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
480
+ return failure ();
481
+
482
+ Location loc = op->getLoc ();
483
+ Value input = adaptor.getSelf ();
484
+ auto inputType = llvm::cast<RankedTensorType>(input.getType ());
485
+ int64_t inputRank = inputType.getRank ();
486
+ unsigned numDims = inputType.getRank ();
487
+ assert (numDims >= 2 && " Not enough input dimensions" );
488
+
489
+ SmallVector<int64_t > padInts;
490
+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (padInts)))
491
+ return rewriter.notifyMatchFailure (
492
+ op, " only support constant int pad ranges" );
493
+
494
+ if (padInts.size () != 6 )
495
+ return rewriter.notifyMatchFailure (
496
+ op, " pad range must have exactly six values" );
497
+
498
+ Value res = input;
499
+ int64_t padIdx = 0 ;
500
+ for (int64_t dim = inputRank - 1 ; dim >= inputRank - 3 ; dim--) {
501
+ int64_t startTileWidth = padInts[padIdx++];
502
+ int64_t endTileWidth = padInts[padIdx++];
503
+
504
+ SmallVector<Value> resultParts;
505
+ if (startTileWidth > 0 ) {
506
+ Value slice = extractSlice (rewriter, loc, res, dim, sliceLoc::START);
507
+ Value tile = createTile (rewriter, loc, slice, startTileWidth, dim);
508
+ resultParts.push_back (tile);
509
+ }
510
+
511
+ resultParts.push_back (res);
512
+
513
+ if (endTileWidth > 0 ) {
514
+ Value slice = extractSlice (rewriter, loc, res, dim, sliceLoc::END);
515
+ Value tile = createTile (rewriter, loc, slice, endTileWidth, dim);
516
+ resultParts.push_back (tile);
517
+ }
518
+
519
+ if (resultParts.size () > 1 )
520
+ res = rewriter.create <tensor::ConcatOp>(loc, dim, resultParts);
521
+ }
522
+
523
+ Type resultType = getTypeConverter ()->convertType (op.getType ());
524
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultType, res);
525
+ return success ();
526
+ }
527
+ };
528
+
529
+ } // namespace
429
530
namespace {
430
531
// Converts constant tensor allocation like ops.
431
532
template <typename OpTy, int fillVal>
@@ -696,6 +797,8 @@ void mlir::torch::torch_to_linalg::
696
797
RewritePatternSet &patterns,
697
798
ConversionTarget &target) {
698
799
MLIRContext *context = patterns.getContext ();
800
+ target.addIllegalOp <AtenReplicationPad3dOp>();
801
+ patterns.add <ConvertAtenReplicationPad3dOp>(typeConverter, context);
699
802
target.addIllegalOp <AtenReplicationPad2dOp>();
700
803
patterns.add <ConvertAtenReplicationPad2dOp>(typeConverter, context);
701
804
target.addIllegalOp <AtenReplicationPad1dOp>();
0 commit comments