@@ -159,25 +159,6 @@ static scf::YieldOp generateYieldAndOrReduceToScfLoop(OpBuilder builder,
159
159
return output;
160
160
}
161
161
162
- // Replace async op with wait_all op
163
- static air::WaitAllOp replaceAsyncOpWithWaitAll (OpBuilder builder,
164
- IRMapping &remap, Operation *op,
165
- bool cloneDepList = true ) {
166
- assert (air::isAsyncOp (op));
167
- SmallVector<Value> dep_list_remap;
168
- if (cloneDepList) {
169
- for (auto dep : air::getAsyncDependenciesFromOp (op)) {
170
- dep_list_remap.push_back (remap.lookupOrDefault (dep));
171
- }
172
- }
173
- auto wa_op = builder.create <air::WaitAllOp>(
174
- builder.getUnknownLoc (), air::AsyncTokenType::get (op->getContext ()),
175
- dep_list_remap);
176
- wa_op->setAttr (" hoist" , StringAttr::get (op->getContext (), " dep" ));
177
- remap.map (air::getAsyncTokenFromOp (op), wa_op.getAsyncToken ());
178
- return wa_op;
179
- }
180
-
181
162
// Clone affine if's block with remap
182
163
static SmallVector<Operation *>
183
164
replaceAffineIfOpWithChannelOpAndClone (OpBuilder builder, IRMapping &remap,
@@ -240,9 +221,12 @@ cloneScfLoopUsingRemap(OpBuilder builder, IRMapping &remap, T loop_op,
240
221
SmallVector<Operation *> clonedOps;
241
222
for (Operation &o : blk->without_terminator ()) {
242
223
if (!o.hasAttr (" hoist" )) {
243
- if (air::isAsyncOp (&o))
244
- clonedOps.push_back (
245
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
224
+ if (air::isAsyncOp (&o)) {
225
+ auto wa_op =
226
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
227
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
228
+ clonedOps.push_back (wa_op);
229
+ }
246
230
continue ;
247
231
}
248
232
@@ -255,8 +239,10 @@ cloneScfLoopUsingRemap(OpBuilder builder, IRMapping &remap, T loop_op,
255
239
" internalGetPut" ) {
256
240
// Found channel op labelled as "internalGetPut", which shouldn't be
257
241
// hoisted
258
- clonedOps.push_back (
259
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
242
+ auto wa_op =
243
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
244
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
245
+ clonedOps.push_back (wa_op);
260
246
} else {
261
247
clonedOps.push_back (builder.clone (o, remap));
262
248
}
@@ -268,13 +254,19 @@ cloneScfLoopUsingRemap(OpBuilder builder, IRMapping &remap, T loop_op,
268
254
} else if (auto dma_op = dyn_cast<air::DmaMemcpyNdOp>(o)) {
269
255
if (o.hasAttr (" loop-carried-dep" ))
270
256
clonedOps.push_back (builder.clone (o, remap));
271
- else
272
- clonedOps.push_back (
273
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
257
+ else {
258
+ auto wa_op =
259
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
260
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
261
+ clonedOps.push_back (wa_op);
262
+ }
274
263
} else if (!air::isPure (&o) && !isa<air::WaitAllOp>(o)) {
275
- if (air::isAsyncOp (&o))
276
- clonedOps.push_back (
277
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
264
+ if (air::isAsyncOp (&o)) {
265
+ auto wa_op =
266
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
267
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
268
+ clonedOps.push_back (wa_op);
269
+ }
278
270
} else {
279
271
clonedOps.push_back (builder.clone (o, remap));
280
272
}
@@ -665,9 +657,12 @@ static void HoistingAffineIf(affine::AffineIfOp op) {
665
657
SmallVector<Operation *> clonedOps;
666
658
for (Operation &o : blk->without_terminator ()) {
667
659
if (!o.hasAttr (" hoist" )) {
668
- if (air::isAsyncOp (&o))
669
- clonedOps.push_back (
670
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
660
+ if (air::isAsyncOp (&o)) {
661
+ auto wa_op =
662
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
663
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
664
+ clonedOps.push_back (wa_op);
665
+ }
671
666
continue ;
672
667
}
673
668
@@ -685,21 +680,29 @@ static void HoistingAffineIf(affine::AffineIfOp op) {
685
680
.str () == " internalGetPut" ) {
686
681
// Found channel op labelled as "internalGetPut", which shouldn't be
687
682
// hoisted
688
- clonedOps.push_back (
689
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
683
+ auto wa_op =
684
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
685
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
686
+ clonedOps.push_back (wa_op);
690
687
} else {
691
688
clonedOps.push_back (builder.clone (o, remap));
692
689
}
693
690
} else if (auto dma_op = dyn_cast<air::DmaMemcpyNdOp>(o)) {
694
691
if (o.hasAttr (" loop-carried-dep" ))
695
692
clonedOps.push_back (builder.clone (o, remap));
696
- else
697
- clonedOps.push_back (
698
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
693
+ else {
694
+ auto wa_op =
695
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
696
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
697
+ clonedOps.push_back (wa_op);
698
+ }
699
699
} else if (!air::isPure (&o) && !isa<air::WaitAllOp>(o)) {
700
- if (air::isAsyncOp (&o))
701
- clonedOps.push_back (
702
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
700
+ if (air::isAsyncOp (&o)) {
701
+ auto wa_op =
702
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
703
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
704
+ clonedOps.push_back (wa_op);
705
+ }
703
706
} else {
704
707
clonedOps.push_back (builder.clone (o, remap));
705
708
}
@@ -843,9 +846,12 @@ class AIRDmaToAIRChannelConversion
843
846
SmallVector<Operation *> clonedOps;
844
847
for (Operation &o : blk->without_terminator ()) {
845
848
if (!o.hasAttr (" hoist" )) {
846
- if (air::isAsyncOp (&o))
847
- clonedOps.push_back (
848
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
849
+ if (air::isAsyncOp (&o)) {
850
+ auto wa_op =
851
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
852
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
853
+ clonedOps.push_back (wa_op);
854
+ }
849
855
continue ;
850
856
}
851
857
if (auto child_for_op = dyn_cast<LoopLikeOpInterface>(o)) {
@@ -858,15 +864,20 @@ class AIRDmaToAIRChannelConversion
858
864
.str () == " internalGetPut" ) {
859
865
// Found channel op labelled as "internalGetPut", which
860
866
// shouldn't be hoisted
861
- clonedOps.push_back (
862
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
867
+ auto wa_op =
868
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
869
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
870
+ clonedOps.push_back (wa_op);
863
871
} else {
864
872
clonedOps.push_back (builder.clone (o, remap));
865
873
}
866
874
} else if (!air::isPure (&o) && !isa<air::WaitAllOp>(o)) {
867
- if (air::isAsyncOp (&o))
868
- clonedOps.push_back (
869
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
875
+ if (air::isAsyncOp (&o)) {
876
+ auto wa_op =
877
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
878
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
879
+ clonedOps.push_back (wa_op);
880
+ }
870
881
} else {
871
882
clonedOps.push_back (builder.clone (o, remap));
872
883
}
@@ -1242,9 +1253,12 @@ class AIRDemoteDmaToAIRHierarchyConversion
1242
1253
SmallVector<Operation *> clonedOps;
1243
1254
for (Operation &o : blk->without_terminator ()) {
1244
1255
if (!o.hasAttr (" hoist" )) {
1245
- if (air::isAsyncOp (&o))
1246
- clonedOps.push_back (
1247
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
1256
+ if (air::isAsyncOp (&o)) {
1257
+ auto wa_op =
1258
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
1259
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
1260
+ clonedOps.push_back (wa_op);
1261
+ }
1248
1262
continue ;
1249
1263
}
1250
1264
@@ -1254,9 +1268,12 @@ class AIRDemoteDmaToAIRHierarchyConversion
1254
1268
} else if (auto memcpy_op = dyn_cast<air::MemcpyInterface>(o)) {
1255
1269
clonedOps.push_back (builder.clone (o, remap));
1256
1270
} else if (!air::isPure (&o) && !isa<air::WaitAllOp>(o)) {
1257
- if (air::isAsyncOp (&o))
1258
- clonedOps.push_back (
1259
- replaceAsyncOpWithWaitAll (builder, remap, &o, false ));
1271
+ if (air::isAsyncOp (&o)) {
1272
+ auto wa_op =
1273
+ air::replaceAsyncOpWithWaitAll (builder, remap, &o, false );
1274
+ wa_op->setAttr (" hoist" , StringAttr::get (o.getContext (), " dep" ));
1275
+ clonedOps.push_back (wa_op);
1276
+ }
1260
1277
} else {
1261
1278
clonedOps.push_back (builder.clone (o, remap));
1262
1279
}
0 commit comments