Skip to content

Commit 383a75d

Browse files
authored Jan 6, 2025
Remove canonicalizeFalseDependencies; move canonicalizeFalseDependencies to CanonicalizeAsyncOpDeps (Xilinx#844)
1 parent 71c2c9b commit 383a75d

File tree

1 file changed

+70
-97
lines changed

1 file changed

+70
-97
lines changed
 

‎mlir/lib/Dialect/AIR/IR/AIRDialect.cpp

+70 -97
Original file line number | Diff line number | Diff line change
@@ -196,33 +196,84 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
196196
template <class OpT>
197197
static LogicalResult CanonicalizeAsyncOpDeps(OpT op,
198198
PatternRewriter &rewriter) {
199-
200-
SmallVector<Value> depsOfDeps;
201-
for (auto v : op.getAsyncDependencies()) {
202-
if (auto asyncOperand =
203-
dyn_cast_if_present<AsyncOpInterface>(v.getDefiningOp())) {
204-
auto deps = asyncOperand.getAsyncDependencies();
205-
depsOfDeps.append(deps.begin(), deps.end());
199+
auto getMemrefsFromVec = [](SmallVector<Value> vec) {
200+
SmallVector<Value> memrefs;
201+
for (auto v : vec)
202+
if (isa<MemRefType>(v.getType()))
203+
memrefs.push_back(v);
204+
return memrefs;
205+
};
206+
auto getAllMemrefsTouchedbyOp = [getMemrefsFromVec](Operation *o) {
207+
llvm::SetVector<Value> memrefs;
208+
SmallVector<Value> vals = o->getOperands();
209+
vals.insert(vals.end(), o->getResults().begin(), o->getResults().end());
210+
SmallVector<Region *> regions;
211+
for (auto &region : o->getRegions())
212+
regions.push_back(&region);
213+
// If air.wait_all, then we analyze the dependency by collecting all
214+
// operations that depend on it.
215+
auto waitAllOp = dyn_cast_if_present<air::WaitAllOp>(o);
216+
if (waitAllOp && waitAllOp.getAsyncToken()) {
217+
for (auto user : waitAllOp.getAsyncToken().getUsers()) {
218+
vals.insert(vals.end(), user->getOperands().begin(),
219+
user->getOperands().end());
220+
vals.insert(vals.end(), user->getResults().begin(),
221+
user->getResults().end());
222+
for (auto &region : user->getRegions())
223+
regions.push_back(&region);
224+
}
206225
}
207-
}
226+
auto memrefvals = getMemrefsFromVec(vals);
227+
memrefs.insert(memrefvals.begin(), memrefvals.end());
228+
for (auto region : regions) {
229+
llvm::SetVector<Value> usedVals;
230+
getUsedValuesDefinedAbove(*region, usedVals);
231+
auto usedMemrefs = getMemrefsFromVec(usedVals.takeVector());
232+
memrefs.insert(usedMemrefs.begin(), usedMemrefs.end());
233+
}
234+
return memrefs;
235+
};
236+
auto memrefsTouchedByOp = getAllMemrefsTouchedbyOp(op.getOperation());
208237
// make a list of new async token operands
209-
SmallVector<Value> newAsyncDeps;
238+
llvm::SetVector<Value> newAsyncDeps; // don't include duplicates
210239
for (auto v : op.getAsyncDependencies()) {
211-
// don't include duplicates
212-
if (std::find(std::begin(newAsyncDeps), std::end(newAsyncDeps), v) !=
213-
std::end(newAsyncDeps))
214-
continue;
215240
// don't include wait_all ops with no operands
216241
if (auto wa = dyn_cast_if_present<WaitAllOp>(v.getDefiningOp()))
217242
if (wa.getAsyncDependencies().size() == 0)
218243
continue;
219-
// don't include a dependency of another dependency
220-
if (std::find(std::begin(depsOfDeps), std::end(depsOfDeps), v) !=
221-
std::end(depsOfDeps))
222-
continue;
223-
newAsyncDeps.push_back(v);
244+
// don't include any wrong dependencies
245+
if (v.getDefiningOp()) {
246+
auto memrefsTouchedByDefOp = getAllMemrefsTouchedbyOp(v.getDefiningOp());
247+
if (!memrefsTouchedByDefOp.empty() && !memrefsTouchedByOp.empty() &&
248+
llvm::none_of(memrefsTouchedByDefOp, [&memrefsTouchedByOp](Value v) {
249+
return llvm::is_contained(memrefsTouchedByOp, v);
250+
})) {
251+
continue;
252+
}
253+
}
254+
newAsyncDeps.insert(v);
224255
}
225256

257+
// don't include a dependency of another dependency
258+
auto getDepsOfDeps = [](llvm::SetVector<Value> deps) {
259+
llvm::SetVector<Value> depsOfDeps;
260+
for (auto v : deps) {
261+
if (auto asyncOperand =
262+
dyn_cast_if_present<AsyncOpInterface>(v.getDefiningOp())) {
263+
auto deps = asyncOperand.getAsyncDependencies();
264+
depsOfDeps.insert(deps.begin(), deps.end());
265+
}
266+
}
267+
return depsOfDeps;
268+
};
269+
llvm::SetVector<Value> erased;
270+
for (auto v : newAsyncDeps) {
271+
if (llvm::is_contained(getDepsOfDeps(newAsyncDeps), v))
272+
erased.insert(v);
273+
}
274+
for (auto e : erased)
275+
newAsyncDeps.remove(e);
276+
226277
// if the operands won't change, return
227278
if (newAsyncDeps.size() == op.getAsyncDependencies().size())
228279
return failure();
@@ -301,77 +352,6 @@ CanonicalizeAsyncLoopCarriedDepsInRegion(OpT op, PatternRewriter &rewriter) {
301352
return success();
302353
}
303354

304-
// Break any wrong async dependencies.
305-
template <class T>
306-
static LogicalResult canonicalizeFalseDependencies(T op,
307-
PatternRewriter &rewriter) {
308-
auto asyncOp = dyn_cast_if_present<air::AsyncOpInterface>(op.getOperation());
309-
if (!asyncOp)
310-
return failure();
311-
if (asyncOp.getAsyncDependencies().empty())
312-
return failure();
313-
314-
auto getMemrefsFromVec = [](SmallVector<Value> vec) {
315-
SmallVector<Value> memrefs;
316-
for (auto v : vec)
317-
if (isa<MemRefType>(v.getType()))
318-
memrefs.push_back(v);
319-
return memrefs;
320-
};
321-
auto getAllMemrefsTouchedbyOp = [getMemrefsFromVec](Operation *o) {
322-
llvm::SetVector<Value> memrefs;
323-
SmallVector<Value> vals = o->getOperands();
324-
vals.insert(vals.end(), o->getResults().begin(), o->getResults().end());
325-
SmallVector<Region *> regions;
326-
for (auto &region : o->getRegions())
327-
regions.push_back(&region);
328-
// If air.wait_all, then we analyze the dependency by collecting all
329-
// operations that depend on it.
330-
auto waitAllOp = dyn_cast_if_present<air::WaitAllOp>(o);
331-
if (waitAllOp && waitAllOp.getAsyncToken()) {
332-
for (auto user : waitAllOp.getAsyncToken().getUsers()) {
333-
vals.insert(vals.end(), user->getOperands().begin(),
334-
user->getOperands().end());
335-
vals.insert(vals.end(), user->getResults().begin(),
336-
user->getResults().end());
337-
for (auto &region : user->getRegions())
338-
regions.push_back(&region);
339-
}
340-
}
341-
auto memrefvals = getMemrefsFromVec(vals);
342-
memrefs.insert(memrefvals.begin(), memrefvals.end());
343-
for (auto region : regions) {
344-
llvm::SetVector<Value> usedVals;
345-
getUsedValuesDefinedAbove(*region, usedVals);
346-
auto usedMemrefs = getMemrefsFromVec(usedVals.takeVector());
347-
memrefs.insert(usedMemrefs.begin(), usedMemrefs.end());
348-
}
349-
return memrefs;
350-
};
351-
352-
auto memrefsTouchedByOp = getAllMemrefsTouchedbyOp(op.getOperation());
353-
if (memrefsTouchedByOp.empty())
354-
return failure();
355-
SmallVector<Value> depList = asyncOp.getAsyncDependencies();
356-
for (int i = depList.size() - 1; i >= 0; i--) {
357-
auto tokDefOp = depList[i].getDefiningOp();
358-
if (!tokDefOp)
359-
continue;
360-
auto memrefsTouchedByDefOp = getAllMemrefsTouchedbyOp(tokDefOp);
361-
if (memrefsTouchedByDefOp.empty())
362-
continue;
363-
if (llvm::none_of(memrefsTouchedByDefOp, [&memrefsTouchedByOp](Value v) {
364-
return llvm::is_contained(memrefsTouchedByOp, v);
365-
})) {
366-
auto newOp = rewriter.clone(*op);
367-
dyn_cast<air::AsyncOpInterface>(newOp).eraseAsyncDependency(i);
368-
rewriter.replaceOp(op, newOp);
369-
return success();
370-
}
371-
}
372-
return failure();
373-
}
374-
375355
//
376356
// LaunchOp
377357
//
@@ -587,7 +567,6 @@ void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
587567
patterns.add(canonicalizeHierarchyOpArgs<LaunchOp>);
588568
patterns.add(CanonicalizeAsyncOpDeps<LaunchOp>);
589569
patterns.add(CanonicalizeAsyncLoopCarriedDepsInRegion<LaunchOp>);
590-
patterns.add(canonicalizeFalseDependencies<LaunchOp>);
591570
}
592571

593572
ArrayRef<BlockArgument> LaunchOp::getIds() {
@@ -850,7 +829,6 @@ void SegmentOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
850829
patterns.add(canonicalizeHierarchyOpArgs<SegmentOp>);
851830
patterns.add(CanonicalizeAsyncOpDeps<SegmentOp>);
852831
patterns.add(CanonicalizeAsyncLoopCarriedDepsInRegion<SegmentOp>);
853-
patterns.add(canonicalizeFalseDependencies<SegmentOp>);
854832
}
855833

856834
ArrayRef<BlockArgument> SegmentOp::getIds() {
@@ -1112,7 +1090,6 @@ void HerdOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
11121090
patterns.add(canonicalizeHierarchyOpArgs<HerdOp>);
11131091
patterns.add(CanonicalizeAsyncOpDeps<HerdOp>);
11141092
patterns.add(CanonicalizeAsyncLoopCarriedDepsInRegion<HerdOp>);
1115-
patterns.add(canonicalizeFalseDependencies<HerdOp>);
11161093
}
11171094

11181095
ArrayRef<BlockArgument> HerdOp::getIds() {
@@ -1234,7 +1211,6 @@ void ExecuteOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
12341211
patterns.add(FoldExecute);
12351212
patterns.add(CanonicalizeAsyncOpDeps<ExecuteOp>);
12361213
patterns.add(CanonicalizeAsyncLoopCarriedDepsInRegion<ExecuteOp>);
1237-
patterns.add(canonicalizeFalseDependencies<ExecuteOp>);
12381214
}
12391215

12401216
//
@@ -1286,7 +1262,6 @@ void WaitAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
12861262
MLIRContext *context) {
12871263
patterns.add(FoldWaitAll);
12881264
patterns.add(CanonicalizeAsyncOpDeps<WaitAllOp>);
1289-
patterns.add(canonicalizeFalseDependencies<WaitAllOp>);
12901265
}
12911266

12921267
// Get strides from MemRefType.
@@ -1549,7 +1524,7 @@ void DmaMemcpyNdOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
15491524
MLIRContext *context) {
15501525
patterns.add(ComposeMemrefOpOnDmaMemcpyNdSrc);
15511526
patterns.add(ComposeMemrefOpOnDmaMemcpyNdDst);
1552-
patterns.add(canonicalizeFalseDependencies<DmaMemcpyNdOp>);
1527+
patterns.add(CanonicalizeAsyncOpDeps<DmaMemcpyNdOp>);
15531528
}
15541529

15551530
//
@@ -1593,7 +1568,6 @@ void ChannelPutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
15931568
MLIRContext *context) {
15941569
patterns.add(ComposeMemrefOpOnChannelOp<ChannelPutOp>);
15951570
patterns.add(CanonicalizeAsyncOpDeps<ChannelPutOp>);
1596-
patterns.add(canonicalizeFalseDependencies<ChannelPutOp>);
15971571
}
15981572

15991573
//
@@ -1604,7 +1578,6 @@ void ChannelGetOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
16041578
MLIRContext *context) {
16051579
patterns.add(ComposeMemrefOpOnChannelOp<ChannelGetOp>);
16061580
patterns.add(CanonicalizeAsyncOpDeps<ChannelGetOp>);
1607-
patterns.add(canonicalizeFalseDependencies<ChannelGetOp>);
16081581
}
16091582

16101583
//

0 commit comments

Comments
 (0)
Please sign in to comment.