Skip to content

Commit 71c2c9b

Browse files
authored
Remove duplicate implementations of replaceAsyncOpWithWaitAll (Xilinx#843)
1 parent 62a3e71 commit 71c2c9b

File tree

4 files changed

+97
-78
lines changed

4 files changed

+97
-78
lines changed

mlir/include/air/Util/Dependency.h

+4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ LogicalResult unrollScfParallel(OpBuilder builder, scf::ParallelOp par,
7777
Operation *originalChanOp, IRMapping remap);
7878
void populateAIRunrollAIRChannelPutGetInScfParallelPatterns(
7979
RewritePatternSet &patterns);
80+
air::WaitAllOp replaceAsyncOpWithWaitAll(OpBuilder builder, IRMapping &remap,
81+
Operation *op,
82+
bool cloneDepList = true);
83+
8084
//===----------------------------------------------------------------------===//
8185
// Dependency graph
8286
//===----------------------------------------------------------------------===//

mlir/lib/Transform/AIRDependencyScheduleOpt.cpp

+4-23
Original file line numberDiff line numberDiff line change
@@ -3231,25 +3231,6 @@ class AIRDeAliasMemref
32313231
}
32323232
};
32333233

3234-
// Replace async op with wait_all op
3235-
static air::WaitAllOp replaceAsyncOpWithWaitAll(OpBuilder builder,
3236-
IRMapping &remap, Operation *op,
3237-
bool cloneDepList = true) {
3238-
assert(air::isAsyncOp(op));
3239-
SmallVector<Value> dep_list_remap;
3240-
if (cloneDepList) {
3241-
for (auto dep : air::getAsyncDependenciesFromOp(op)) {
3242-
dep_list_remap.push_back(remap.lookupOrDefault(dep));
3243-
}
3244-
}
3245-
auto wa_op = builder.create<air::WaitAllOp>(
3246-
builder.getUnknownLoc(), air::AsyncTokenType::get(op->getContext()),
3247-
dep_list_remap);
3248-
wa_op->setAttr("hoist", StringAttr::get(op->getContext(), "dep"));
3249-
remap.map(air::getAsyncTokenFromOp(op), wa_op.getAsyncToken());
3250-
return wa_op;
3251-
}
3252-
32533234
// A pass which transform multiple channel ops into one, where the data movement
32543235
// is time-multiplexed.
32553236
class AIRFuseChannels
@@ -3993,7 +3974,7 @@ class AIRFuseChannels
39933974
if (air::isAsyncOp(b)) {
39943975
IRMapping waitAllRemap;
39953976
builder.setInsertionPoint(b);
3996-
auto waitAll = replaceAsyncOpWithWaitAll(builder, waitAllRemap, b);
3977+
auto waitAll = air::replaceAsyncOpWithWaitAll(builder, waitAllRemap, b);
39973978
air::getAsyncTokenFromOp(b).replaceAllUsesWith(waitAll.getAsyncToken());
39983979
}
39993980
b->erase();
@@ -4056,7 +4037,7 @@ class AIRFuseChannels
40564037
if (air::isAsyncOp(b)) {
40574038
IRMapping remap;
40584039
builder.setInsertionPoint(b);
4059-
auto waitAll = replaceAsyncOpWithWaitAll(builder, remap, b);
4040+
auto waitAll = air::replaceAsyncOpWithWaitAll(builder, remap, b);
40604041
air::getAsyncTokenFromOp(b).replaceAllUsesWith(waitAll.getAsyncToken());
40614042
}
40624043
b->erase();
@@ -5189,12 +5170,12 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter,
51895170
air::WaitAllOp waitAll = air::WaitAllOp();
51905171
if (dealloc) {
51915172
rewriter.setInsertionPoint(dealloc);
5192-
waitAll = replaceAsyncOpWithWaitAll(rewriter, remap, dealloc);
5173+
waitAll = air::replaceAsyncOpWithWaitAll(rewriter, remap, dealloc);
51935174
dealloc.getAsyncToken().replaceAllUsesWith(waitAll.getAsyncToken());
51945175
rewriter.eraseOp(dealloc);
51955176
}
51965177
rewriter.setInsertionPoint(alloc);
5197-
waitAll = replaceAsyncOpWithWaitAll(rewriter, remap, alloc);
5178+
waitAll = air::replaceAsyncOpWithWaitAll(rewriter, remap, alloc);
51985179
alloc.getAsyncToken().replaceAllUsesWith(waitAll.getAsyncToken());
51995180
rewriter.eraseOp(alloc);
52005181
}

mlir/lib/Transform/AIRDmaToChannel.cpp

+72-55
Original file line numberDiff line numberDiff line change
@@ -159,25 +159,6 @@ static scf::YieldOp generateYieldAndOrReduceToScfLoop(OpBuilder builder,
159159
return output;
160160
}
161161

162-
// Replace async op with wait_all op
163-
static air::WaitAllOp replaceAsyncOpWithWaitAll(OpBuilder builder,
164-
IRMapping &remap, Operation *op,
165-
bool cloneDepList = true) {
166-
assert(air::isAsyncOp(op));
167-
SmallVector<Value> dep_list_remap;
168-
if (cloneDepList) {
169-
for (auto dep : air::getAsyncDependenciesFromOp(op)) {
170-
dep_list_remap.push_back(remap.lookupOrDefault(dep));
171-
}
172-
}
173-
auto wa_op = builder.create<air::WaitAllOp>(
174-
builder.getUnknownLoc(), air::AsyncTokenType::get(op->getContext()),
175-
dep_list_remap);
176-
wa_op->setAttr("hoist", StringAttr::get(op->getContext(), "dep"));
177-
remap.map(air::getAsyncTokenFromOp(op), wa_op.getAsyncToken());
178-
return wa_op;
179-
}
180-
181162
// Clone affine if's block with remap
182163
static SmallVector<Operation *>
183164
replaceAffineIfOpWithChannelOpAndClone(OpBuilder builder, IRMapping &remap,
@@ -240,9 +221,12 @@ cloneScfLoopUsingRemap(OpBuilder builder, IRMapping &remap, T loop_op,
240221
SmallVector<Operation *> clonedOps;
241222
for (Operation &o : blk->without_terminator()) {
242223
if (!o.hasAttr("hoist")) {
243-
if (air::isAsyncOp(&o))
244-
clonedOps.push_back(
245-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
224+
if (air::isAsyncOp(&o)) {
225+
auto wa_op =
226+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
227+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
228+
clonedOps.push_back(wa_op);
229+
}
246230
continue;
247231
}
248232

@@ -255,8 +239,10 @@ cloneScfLoopUsingRemap(OpBuilder builder, IRMapping &remap, T loop_op,
255239
"internalGetPut") {
256240
// Found channel op labelled as "internalGetPut", which shouldn't be
257241
// hoisted
258-
clonedOps.push_back(
259-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
242+
auto wa_op =
243+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
244+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
245+
clonedOps.push_back(wa_op);
260246
} else {
261247
clonedOps.push_back(builder.clone(o, remap));
262248
}
@@ -268,13 +254,19 @@ cloneScfLoopUsingRemap(OpBuilder builder, IRMapping &remap, T loop_op,
268254
} else if (auto dma_op = dyn_cast<air::DmaMemcpyNdOp>(o)) {
269255
if (o.hasAttr("loop-carried-dep"))
270256
clonedOps.push_back(builder.clone(o, remap));
271-
else
272-
clonedOps.push_back(
273-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
257+
else {
258+
auto wa_op =
259+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
260+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
261+
clonedOps.push_back(wa_op);
262+
}
274263
} else if (!air::isPure(&o) && !isa<air::WaitAllOp>(o)) {
275-
if (air::isAsyncOp(&o))
276-
clonedOps.push_back(
277-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
264+
if (air::isAsyncOp(&o)) {
265+
auto wa_op =
266+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
267+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
268+
clonedOps.push_back(wa_op);
269+
}
278270
} else {
279271
clonedOps.push_back(builder.clone(o, remap));
280272
}
@@ -665,9 +657,12 @@ static void HoistingAffineIf(affine::AffineIfOp op) {
665657
SmallVector<Operation *> clonedOps;
666658
for (Operation &o : blk->without_terminator()) {
667659
if (!o.hasAttr("hoist")) {
668-
if (air::isAsyncOp(&o))
669-
clonedOps.push_back(
670-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
660+
if (air::isAsyncOp(&o)) {
661+
auto wa_op =
662+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
663+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
664+
clonedOps.push_back(wa_op);
665+
}
671666
continue;
672667
}
673668

@@ -685,21 +680,29 @@ static void HoistingAffineIf(affine::AffineIfOp op) {
685680
.str() == "internalGetPut") {
686681
// Found channel op labelled as "internalGetPut", which shouldn't be
687682
// hoisted
688-
clonedOps.push_back(
689-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
683+
auto wa_op =
684+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
685+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
686+
clonedOps.push_back(wa_op);
690687
} else {
691688
clonedOps.push_back(builder.clone(o, remap));
692689
}
693690
} else if (auto dma_op = dyn_cast<air::DmaMemcpyNdOp>(o)) {
694691
if (o.hasAttr("loop-carried-dep"))
695692
clonedOps.push_back(builder.clone(o, remap));
696-
else
697-
clonedOps.push_back(
698-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
693+
else {
694+
auto wa_op =
695+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
696+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
697+
clonedOps.push_back(wa_op);
698+
}
699699
} else if (!air::isPure(&o) && !isa<air::WaitAllOp>(o)) {
700-
if (air::isAsyncOp(&o))
701-
clonedOps.push_back(
702-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
700+
if (air::isAsyncOp(&o)) {
701+
auto wa_op =
702+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
703+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
704+
clonedOps.push_back(wa_op);
705+
}
703706
} else {
704707
clonedOps.push_back(builder.clone(o, remap));
705708
}
@@ -843,9 +846,12 @@ class AIRDmaToAIRChannelConversion
843846
SmallVector<Operation *> clonedOps;
844847
for (Operation &o : blk->without_terminator()) {
845848
if (!o.hasAttr("hoist")) {
846-
if (air::isAsyncOp(&o))
847-
clonedOps.push_back(
848-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
849+
if (air::isAsyncOp(&o)) {
850+
auto wa_op =
851+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
852+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
853+
clonedOps.push_back(wa_op);
854+
}
849855
continue;
850856
}
851857
if (auto child_for_op = dyn_cast<LoopLikeOpInterface>(o)) {
@@ -858,15 +864,20 @@ class AIRDmaToAIRChannelConversion
858864
.str() == "internalGetPut") {
859865
// Found channel op labelled as "internalGetPut", which
860866
// shouldn't be hoisted
861-
clonedOps.push_back(
862-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
867+
auto wa_op =
868+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
869+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
870+
clonedOps.push_back(wa_op);
863871
} else {
864872
clonedOps.push_back(builder.clone(o, remap));
865873
}
866874
} else if (!air::isPure(&o) && !isa<air::WaitAllOp>(o)) {
867-
if (air::isAsyncOp(&o))
868-
clonedOps.push_back(
869-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
875+
if (air::isAsyncOp(&o)) {
876+
auto wa_op =
877+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
878+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
879+
clonedOps.push_back(wa_op);
880+
}
870881
} else {
871882
clonedOps.push_back(builder.clone(o, remap));
872883
}
@@ -1242,9 +1253,12 @@ class AIRDemoteDmaToAIRHierarchyConversion
12421253
SmallVector<Operation *> clonedOps;
12431254
for (Operation &o : blk->without_terminator()) {
12441255
if (!o.hasAttr("hoist")) {
1245-
if (air::isAsyncOp(&o))
1246-
clonedOps.push_back(
1247-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
1256+
if (air::isAsyncOp(&o)) {
1257+
auto wa_op =
1258+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
1259+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
1260+
clonedOps.push_back(wa_op);
1261+
}
12481262
continue;
12491263
}
12501264

@@ -1254,9 +1268,12 @@ class AIRDemoteDmaToAIRHierarchyConversion
12541268
} else if (auto memcpy_op = dyn_cast<air::MemcpyInterface>(o)) {
12551269
clonedOps.push_back(builder.clone(o, remap));
12561270
} else if (!air::isPure(&o) && !isa<air::WaitAllOp>(o)) {
1257-
if (air::isAsyncOp(&o))
1258-
clonedOps.push_back(
1259-
replaceAsyncOpWithWaitAll(builder, remap, &o, false));
1271+
if (air::isAsyncOp(&o)) {
1272+
auto wa_op =
1273+
air::replaceAsyncOpWithWaitAll(builder, remap, &o, false);
1274+
wa_op->setAttr("hoist", StringAttr::get(o.getContext(), "dep"));
1275+
clonedOps.push_back(wa_op);
1276+
}
12601277
} else {
12611278
clonedOps.push_back(builder.clone(o, remap));
12621279
}

mlir/lib/Util/Dependency.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,23 @@ void populateAIRunrollAIRChannelPutGetInScfParallelPatterns(
936936
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
937937
}
938938

939+
// Replace async op with wait_all op
940+
air::WaitAllOp replaceAsyncOpWithWaitAll(OpBuilder builder, IRMapping &remap,
941+
Operation *op, bool cloneDepList) {
942+
assert(air::isAsyncOp(op));
943+
SmallVector<Value> dep_list_remap;
944+
if (cloneDepList) {
945+
for (auto dep : air::getAsyncDependenciesFromOp(op)) {
946+
dep_list_remap.push_back(remap.lookupOrDefault(dep));
947+
}
948+
}
949+
auto wa_op = builder.create<air::WaitAllOp>(
950+
builder.getUnknownLoc(), air::AsyncTokenType::get(op->getContext()),
951+
dep_list_remap);
952+
remap.map(air::getAsyncTokenFromOp(op), wa_op.getAsyncToken());
953+
return wa_op;
954+
}
955+
939956
//===----------------------------------------------------------------------===//
940957
// Dependency graph
941958
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)