@@ -196,33 +196,84 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
196
196
template <class OpT >
197
197
static LogicalResult CanonicalizeAsyncOpDeps (OpT op,
198
198
PatternRewriter &rewriter) {
199
-
200
- SmallVector<Value> depsOfDeps;
201
- for (auto v : op.getAsyncDependencies ()) {
202
- if (auto asyncOperand =
203
- dyn_cast_if_present<AsyncOpInterface>(v.getDefiningOp ())) {
204
- auto deps = asyncOperand.getAsyncDependencies ();
205
- depsOfDeps.append (deps.begin (), deps.end ());
199
+ auto getMemrefsFromVec = [](SmallVector<Value> vec) {
200
+ SmallVector<Value> memrefs;
201
+ for (auto v : vec)
202
+ if (isa<MemRefType>(v.getType ()))
203
+ memrefs.push_back (v);
204
+ return memrefs;
205
+ };
206
+ auto getAllMemrefsTouchedbyOp = [getMemrefsFromVec](Operation *o) {
207
+ llvm::SetVector<Value> memrefs;
208
+ SmallVector<Value> vals = o->getOperands ();
209
+ vals.insert (vals.end (), o->getResults ().begin (), o->getResults ().end ());
210
+ SmallVector<Region *> regions;
211
+ for (auto &region : o->getRegions ())
212
+ regions.push_back (&region);
213
+ // If air.wait_all, then we analyze the dependency by collecting all
214
+ // operations that depend on it.
215
+ auto waitAllOp = dyn_cast_if_present<air::WaitAllOp>(o);
216
+ if (waitAllOp && waitAllOp.getAsyncToken ()) {
217
+ for (auto user : waitAllOp.getAsyncToken ().getUsers ()) {
218
+ vals.insert (vals.end (), user->getOperands ().begin (),
219
+ user->getOperands ().end ());
220
+ vals.insert (vals.end (), user->getResults ().begin (),
221
+ user->getResults ().end ());
222
+ for (auto &region : user->getRegions ())
223
+ regions.push_back (&region);
224
+ }
206
225
}
207
- }
226
+ auto memrefvals = getMemrefsFromVec (vals);
227
+ memrefs.insert (memrefvals.begin (), memrefvals.end ());
228
+ for (auto region : regions) {
229
+ llvm::SetVector<Value> usedVals;
230
+ getUsedValuesDefinedAbove (*region, usedVals);
231
+ auto usedMemrefs = getMemrefsFromVec (usedVals.takeVector ());
232
+ memrefs.insert (usedMemrefs.begin (), usedMemrefs.end ());
233
+ }
234
+ return memrefs;
235
+ };
236
+ auto memrefsTouchedByOp = getAllMemrefsTouchedbyOp (op.getOperation ());
208
237
// make a list of new async token operands
209
- SmallVector <Value> newAsyncDeps;
238
+ llvm::SetVector <Value> newAsyncDeps; // don't include duplicates
210
239
for (auto v : op.getAsyncDependencies ()) {
211
- // don't include duplicates
212
- if (std::find (std::begin (newAsyncDeps), std::end (newAsyncDeps), v) !=
213
- std::end (newAsyncDeps))
214
- continue ;
215
240
// don't include wait_all ops with no operands
216
241
if (auto wa = dyn_cast_if_present<WaitAllOp>(v.getDefiningOp ()))
217
242
if (wa.getAsyncDependencies ().size () == 0 )
218
243
continue ;
219
- // don't include a dependency of another dependency
220
- if (std::find (std::begin (depsOfDeps), std::end (depsOfDeps), v) !=
221
- std::end (depsOfDeps))
222
- continue ;
223
- newAsyncDeps.push_back (v);
244
+ // don't include any wrong dependencies
245
+ if (v.getDefiningOp ()) {
246
+ auto memrefsTouchedByDefOp = getAllMemrefsTouchedbyOp (v.getDefiningOp ());
247
+ if (!memrefsTouchedByDefOp.empty () && !memrefsTouchedByOp.empty () &&
248
+ llvm::none_of (memrefsTouchedByDefOp, [&memrefsTouchedByOp](Value v) {
249
+ return llvm::is_contained (memrefsTouchedByOp, v);
250
+ })) {
251
+ continue ;
252
+ }
253
+ }
254
+ newAsyncDeps.insert (v);
224
255
}
225
256
257
+ // don't include a dependency of another dependency
258
+ auto getDepsOfDeps = [](llvm::SetVector<Value> deps) {
259
+ llvm::SetVector<Value> depsOfDeps;
260
+ for (auto v : deps) {
261
+ if (auto asyncOperand =
262
+ dyn_cast_if_present<AsyncOpInterface>(v.getDefiningOp ())) {
263
+ auto deps = asyncOperand.getAsyncDependencies ();
264
+ depsOfDeps.insert (deps.begin (), deps.end ());
265
+ }
266
+ }
267
+ return depsOfDeps;
268
+ };
269
+ llvm::SetVector<Value> erased;
270
+ for (auto v : newAsyncDeps) {
271
+ if (llvm::is_contained (getDepsOfDeps (newAsyncDeps), v))
272
+ erased.insert (v);
273
+ }
274
+ for (auto e : erased)
275
+ newAsyncDeps.remove (e);
276
+
226
277
// if the operands won't change, return
227
278
if (newAsyncDeps.size () == op.getAsyncDependencies ().size ())
228
279
return failure ();
@@ -301,77 +352,6 @@ CanonicalizeAsyncLoopCarriedDepsInRegion(OpT op, PatternRewriter &rewriter) {
301
352
return success ();
302
353
}
303
354
304
- // Break any wrong async dependencies.
305
- template <class T >
306
- static LogicalResult canonicalizeFalseDependencies (T op,
307
- PatternRewriter &rewriter) {
308
- auto asyncOp = dyn_cast_if_present<air::AsyncOpInterface>(op.getOperation ());
309
- if (!asyncOp)
310
- return failure ();
311
- if (asyncOp.getAsyncDependencies ().empty ())
312
- return failure ();
313
-
314
- auto getMemrefsFromVec = [](SmallVector<Value> vec) {
315
- SmallVector<Value> memrefs;
316
- for (auto v : vec)
317
- if (isa<MemRefType>(v.getType ()))
318
- memrefs.push_back (v);
319
- return memrefs;
320
- };
321
- auto getAllMemrefsTouchedbyOp = [getMemrefsFromVec](Operation *o) {
322
- llvm::SetVector<Value> memrefs;
323
- SmallVector<Value> vals = o->getOperands ();
324
- vals.insert (vals.end (), o->getResults ().begin (), o->getResults ().end ());
325
- SmallVector<Region *> regions;
326
- for (auto &region : o->getRegions ())
327
- regions.push_back (&region);
328
- // If air.wait_all, then we analyze the dependency by collecting all
329
- // operations that depend on it.
330
- auto waitAllOp = dyn_cast_if_present<air::WaitAllOp>(o);
331
- if (waitAllOp && waitAllOp.getAsyncToken ()) {
332
- for (auto user : waitAllOp.getAsyncToken ().getUsers ()) {
333
- vals.insert (vals.end (), user->getOperands ().begin (),
334
- user->getOperands ().end ());
335
- vals.insert (vals.end (), user->getResults ().begin (),
336
- user->getResults ().end ());
337
- for (auto &region : user->getRegions ())
338
- regions.push_back (&region);
339
- }
340
- }
341
- auto memrefvals = getMemrefsFromVec (vals);
342
- memrefs.insert (memrefvals.begin (), memrefvals.end ());
343
- for (auto region : regions) {
344
- llvm::SetVector<Value> usedVals;
345
- getUsedValuesDefinedAbove (*region, usedVals);
346
- auto usedMemrefs = getMemrefsFromVec (usedVals.takeVector ());
347
- memrefs.insert (usedMemrefs.begin (), usedMemrefs.end ());
348
- }
349
- return memrefs;
350
- };
351
-
352
- auto memrefsTouchedByOp = getAllMemrefsTouchedbyOp (op.getOperation ());
353
- if (memrefsTouchedByOp.empty ())
354
- return failure ();
355
- SmallVector<Value> depList = asyncOp.getAsyncDependencies ();
356
- for (int i = depList.size () - 1 ; i >= 0 ; i--) {
357
- auto tokDefOp = depList[i].getDefiningOp ();
358
- if (!tokDefOp)
359
- continue ;
360
- auto memrefsTouchedByDefOp = getAllMemrefsTouchedbyOp (tokDefOp);
361
- if (memrefsTouchedByDefOp.empty ())
362
- continue ;
363
- if (llvm::none_of (memrefsTouchedByDefOp, [&memrefsTouchedByOp](Value v) {
364
- return llvm::is_contained (memrefsTouchedByOp, v);
365
- })) {
366
- auto newOp = rewriter.clone (*op);
367
- dyn_cast<air::AsyncOpInterface>(newOp).eraseAsyncDependency (i);
368
- rewriter.replaceOp (op, newOp);
369
- return success ();
370
- }
371
- }
372
- return failure ();
373
- }
374
-
375
355
//
376
356
// LaunchOp
377
357
//
@@ -587,7 +567,6 @@ void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
587
567
patterns.add (canonicalizeHierarchyOpArgs<LaunchOp>);
588
568
patterns.add (CanonicalizeAsyncOpDeps<LaunchOp>);
589
569
patterns.add (CanonicalizeAsyncLoopCarriedDepsInRegion<LaunchOp>);
590
- patterns.add (canonicalizeFalseDependencies<LaunchOp>);
591
570
}
592
571
593
572
ArrayRef<BlockArgument> LaunchOp::getIds () {
@@ -850,7 +829,6 @@ void SegmentOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
850
829
patterns.add (canonicalizeHierarchyOpArgs<SegmentOp>);
851
830
patterns.add (CanonicalizeAsyncOpDeps<SegmentOp>);
852
831
patterns.add (CanonicalizeAsyncLoopCarriedDepsInRegion<SegmentOp>);
853
- patterns.add (canonicalizeFalseDependencies<SegmentOp>);
854
832
}
855
833
856
834
ArrayRef<BlockArgument> SegmentOp::getIds () {
@@ -1112,7 +1090,6 @@ void HerdOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1112
1090
patterns.add (canonicalizeHierarchyOpArgs<HerdOp>);
1113
1091
patterns.add (CanonicalizeAsyncOpDeps<HerdOp>);
1114
1092
patterns.add (CanonicalizeAsyncLoopCarriedDepsInRegion<HerdOp>);
1115
- patterns.add (canonicalizeFalseDependencies<HerdOp>);
1116
1093
}
1117
1094
1118
1095
ArrayRef<BlockArgument> HerdOp::getIds () {
@@ -1234,7 +1211,6 @@ void ExecuteOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1234
1211
patterns.add (FoldExecute);
1235
1212
patterns.add (CanonicalizeAsyncOpDeps<ExecuteOp>);
1236
1213
patterns.add (CanonicalizeAsyncLoopCarriedDepsInRegion<ExecuteOp>);
1237
- patterns.add (canonicalizeFalseDependencies<ExecuteOp>);
1238
1214
}
1239
1215
1240
1216
//
@@ -1286,7 +1262,6 @@ void WaitAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1286
1262
MLIRContext *context) {
1287
1263
patterns.add (FoldWaitAll);
1288
1264
patterns.add (CanonicalizeAsyncOpDeps<WaitAllOp>);
1289
- patterns.add (canonicalizeFalseDependencies<WaitAllOp>);
1290
1265
}
1291
1266
1292
1267
// Get strides from MemRefType.
@@ -1549,7 +1524,7 @@ void DmaMemcpyNdOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1549
1524
MLIRContext *context) {
1550
1525
patterns.add (ComposeMemrefOpOnDmaMemcpyNdSrc);
1551
1526
patterns.add (ComposeMemrefOpOnDmaMemcpyNdDst);
1552
- patterns.add (canonicalizeFalseDependencies <DmaMemcpyNdOp>);
1527
+ patterns.add (CanonicalizeAsyncOpDeps <DmaMemcpyNdOp>);
1553
1528
}
1554
1529
1555
1530
//
@@ -1593,7 +1568,6 @@ void ChannelPutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1593
1568
MLIRContext *context) {
1594
1569
patterns.add (ComposeMemrefOpOnChannelOp<ChannelPutOp>);
1595
1570
patterns.add (CanonicalizeAsyncOpDeps<ChannelPutOp>);
1596
- patterns.add (canonicalizeFalseDependencies<ChannelPutOp>);
1597
1571
}
1598
1572
1599
1573
//
@@ -1604,7 +1578,6 @@ void ChannelGetOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1604
1578
MLIRContext *context) {
1605
1579
patterns.add (ComposeMemrefOpOnChannelOp<ChannelGetOp>);
1606
1580
patterns.add (CanonicalizeAsyncOpDeps<ChannelGetOp>);
1607
- patterns.add (canonicalizeFalseDependencies<ChannelGetOp>);
1608
1581
}
1609
1582
1610
1583
//
0 commit comments