@@ -3257,178 +3257,6 @@ class AIREnforceLoopCarriedMemrefDeallocPattern
3257
3257
private:
3258
3258
};
3259
3259
3260
- // A pass which de-alias a memref with multiple channel accesses over time, into
3261
- // multiple memrefs. Note that this implementation is temporary and not generic.
3262
- // TODO: Rewrite as a graph partitioning problem.
3263
- class AIRDeAliasMemref
3264
- : public xilinx::air::impl::AIRDeAliasMemrefBase<AIRDeAliasMemref> {
3265
-
3266
- public:
3267
- AIRDeAliasMemref () = default ;
3268
- AIRDeAliasMemref (const AIRDeAliasMemref &pass) {}
3269
-
3270
- void getDependentDialects (::mlir::DialectRegistry ®istry) const override {
3271
- registry.insert <scf::SCFDialect, air::airDialect>();
3272
- }
3273
-
3274
- void runOnFunction (func::FuncOp f) {
3275
-
3276
- std::vector<memref::AllocOp> allocs;
3277
- f.walk ([&](memref::AllocOp alloc) { allocs.push_back (alloc); });
3278
-
3279
- // Count air.channel references
3280
- for (auto alloc : allocs) {
3281
- Value memref = nullptr ;
3282
- if (auto exec = alloc->getParentOfType <air::ExecuteOp>()) {
3283
- memref = exec->getResult (1 );
3284
- } else
3285
- memref = alloc.getMemref ();
3286
- std::vector<air::ChannelInterface> chan_puts_gets;
3287
- for (auto user : memref.getUsers ()) {
3288
- if (auto putget = dyn_cast<air::ChannelInterface>(user))
3289
- if (putget.getMemref () == memref)
3290
- chan_puts_gets.push_back (putget);
3291
- }
3292
-
3293
- // Partition the subgraph
3294
- std::vector<int > partition_cuts;
3295
- if (!chan_puts_gets.empty ()) {
3296
- for (unsigned i = 0 ; i < chan_puts_gets.size () - 1 ; i++) {
3297
- if (isa<air::ChannelGetOp>(chan_puts_gets[i].getOperation ()) &&
3298
- isa<air::ChannelPutOp>(chan_puts_gets[i + 1 ].getOperation ())) {
3299
- partition_cuts.push_back (i + 1 );
3300
- }
3301
- }
3302
- }
3303
-
3304
- // Allocate new memref per cut
3305
- std::vector<Operation *> new_memallocs;
3306
- for (unsigned i = 0 ; i < partition_cuts.size (); i++) {
3307
- OpBuilder builder (alloc);
3308
- Operation *new_op = nullptr ;
3309
- if (auto exec = alloc->getParentOfType <air::ExecuteOp>()) {
3310
- builder.setInsertionPoint (exec);
3311
- new_op = builder.clone (*exec.getOperation ());
3312
- } else
3313
- new_op = builder.clone (*alloc.getOperation ());
3314
- new_memallocs.push_back (new_op);
3315
-
3316
- // Create deallocs for the new memref
3317
- Value new_memref = isa<air::ExecuteOp>(new_op) ? new_op->getResult (1 )
3318
- : new_op->getResult (0 );
3319
- for (auto user : memref.getUsers ()) {
3320
- if (isa<memref::DeallocOp>(user)) {
3321
- if (isa<air::ExecuteOp>(new_op)) {
3322
- builder.setInsertionPoint (
3323
- user->getParentOfType <air::ExecuteOp>());
3324
- // Async. dealloc
3325
- auto async_exec = builder.create <xilinx::air::ExecuteOp>(
3326
- user->getLoc (), air::AsyncTokenType::get (alloc->getContext ()),
3327
- SmallVector<Value>{});
3328
- Block *async_exec_bb =
3329
- builder.createBlock (&async_exec.getRegion ());
3330
- builder.setInsertionPointToStart (async_exec_bb);
3331
- builder.create <memref::DeallocOp>(user->getLoc (), new_memref);
3332
- builder.create <air::ExecuteTerminatorOp>(user->getLoc ());
3333
- } else {
3334
- builder.setInsertionPoint (user);
3335
- // Sync. dealloc
3336
- builder.create <memref::DeallocOp>(user->getLoc (), new_memref);
3337
- }
3338
- }
3339
- }
3340
- }
3341
-
3342
- // Update references
3343
- partition_cuts.insert (partition_cuts.end (), chan_puts_gets.size ());
3344
- for (unsigned i = 0 ; i < partition_cuts.size () - 1 ; i++) {
3345
- for (int j = partition_cuts[i]; j < partition_cuts[i + 1 ]; j++) {
3346
- if (auto old_put = dyn_cast<air::ChannelPutOp>(
3347
- chan_puts_gets[j].getOperation ())) {
3348
- Value new_memref = isa<air::ExecuteOp>(new_memallocs[i])
3349
- ? new_memallocs[i]->getResult (1 )
3350
- : new_memallocs[i]->getResult (0 );
3351
- OpBuilder builder (old_put);
3352
- replaceChannelPutOp (builder, old_put, new_memref);
3353
- } else if (auto old_get = dyn_cast<air::ChannelGetOp>(
3354
- chan_puts_gets[j].getOperation ())) {
3355
- Value new_memref = isa<air::ExecuteOp>(new_memallocs[i])
3356
- ? new_memallocs[i]->getResult (1 )
3357
- : new_memallocs[i]->getResult (0 );
3358
- OpBuilder builder (old_get);
3359
- replaceChannelGetOp (builder, old_get, new_memref);
3360
- }
3361
- }
3362
- }
3363
- }
3364
- }
3365
-
3366
- void runOnOperation () override {
3367
- auto module = getOperation ();
3368
-
3369
- SmallVector<func::FuncOp, 4 > funcOps;
3370
- module.walk ([&](func::FuncOp op) { funcOps.push_back (op); });
3371
- for (auto f : funcOps) {
3372
- runOnFunction (f);
3373
- }
3374
- }
3375
-
3376
- private:
3377
- Operation *replaceChannelPutOp (OpBuilder builder, air::ChannelPutOp old,
3378
- Value new_memref) {
3379
- builder.setInsertionPoint (old);
3380
- SmallVector<Type, 1 > tys;
3381
- if (old.getAsyncToken ()) {
3382
- tys.push_back (air::AsyncTokenType::get (old->getContext ()));
3383
- }
3384
- SmallVector<Value, 4 > deps = old.getAsyncDependencies ();
3385
- auto new_op = builder.create <air::ChannelPutOp>(
3386
- old->getLoc (), tys, deps, old.getChanName (), old.getIndices (),
3387
- new_memref, old.getSrcOffsets (), old.getSrcSizes (),
3388
- old.getSrcStrides ());
3389
- if (old.getAsyncToken ()) {
3390
- old.getAsyncToken ().replaceAllUsesWith (new_op.getAsyncToken ());
3391
- // Add dependence to the new memref
3392
- new_op.addAsyncDependency (
3393
- dyn_cast<air::ExecuteOp>(new_memref.getDefiningOp ()).getAsyncToken ());
3394
- }
3395
- if (old.getId () != -1 ) {
3396
- new_op->setAttr (" id" , mlir::IntegerAttr::get (
3397
- mlir::IntegerType::get (old->getContext (), 32 ),
3398
- old.getId ()));
3399
- }
3400
- old->erase ();
3401
- return new_op.getOperation ();
3402
- }
3403
- Operation *replaceChannelGetOp (OpBuilder builder, air::ChannelGetOp old,
3404
- Value new_memref) {
3405
- builder.setInsertionPoint (old);
3406
- SmallVector<Type, 1 > tys;
3407
- if (old.getAsyncToken ()) {
3408
- tys.push_back (air::AsyncTokenType::get (old->getContext ()));
3409
- }
3410
- SmallVector<Value, 4 > deps = old.getAsyncDependencies ();
3411
- auto new_op = builder.create <air::ChannelGetOp>(
3412
- old->getLoc (), tys, deps, old.getChanName (), old.getIndices (),
3413
- new_memref, old.getDstOffsets (), old.getDstSizes (),
3414
- old.getDstStrides ());
3415
- new_op->setAttrs (old->getDiscardableAttrDictionary ());
3416
- if (old.getAsyncToken ()) {
3417
- old.getAsyncToken ().replaceAllUsesWith (new_op.getAsyncToken ());
3418
- // Add dependence to the new memref
3419
- new_op.addAsyncDependency (
3420
- dyn_cast<air::ExecuteOp>(new_memref.getDefiningOp ()).getAsyncToken ());
3421
- }
3422
- if (old.getId () != -1 ) {
3423
- new_op->setAttr (" id" , mlir::IntegerAttr::get (
3424
- mlir::IntegerType::get (old->getContext (), 32 ),
3425
- old.getId ()));
3426
- }
3427
- old->erase ();
3428
- return new_op.getOperation ();
3429
- }
3430
- };
3431
-
3432
3260
// A pass which transform multiple channel ops into one, where the data movement
3433
3261
// is time-multiplexed.
3434
3262
class AIRFuseChannels
@@ -6175,10 +6003,6 @@ std::unique_ptr<Pass> createAIREnforceLoopCarriedMemrefDeallocPattern() {
6175
6003
return std::make_unique<AIREnforceLoopCarriedMemrefDeallocPattern>();
6176
6004
}
6177
6005
6178
- std::unique_ptr<Pass> createAIRDeAliasMemref () {
6179
- return std::make_unique<AIRDeAliasMemref>();
6180
- }
6181
-
6182
6006
std::unique_ptr<Pass> createAIRFuseChannels () {
6183
6007
return std::make_unique<AIRFuseChannels>();
6184
6008
}
0 commit comments