Skip to content

Commit 33df7ab

Browse files
authored
AIRDependency: Rewrite logic around replacing uses of async tokens to an erased (hoisted) op (Xilinx#786)
* Rewrite logic around replacing uses of async tokens to an erased (hoisted) op * Removed unused var
1 parent b55697a commit 33df7ab

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

Diff for: mlir/lib/Util/Dependency.cpp

+13-16
Original file line numberDiff line numberDiff line change
@@ -752,25 +752,22 @@ scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter,
752752
for (auto erase_op : target_ops) {
753753
// Reconnect returned tokens.
754754
rewriter.setInsertionPoint(erase_op);
755+
SmallVector<Value> erase_op_async_deps =
756+
air::getAsyncDependenciesFromOp(erase_op);
755757
for (auto res : erase_op->getResults()) {
756758
if (!isa<air::AsyncTokenType>(res.getType()))
757759
continue;
758-
for (auto u : res.getUsers()) {
759-
if (auto async_user = dyn_cast<air::AsyncOpInterface>(u)) {
760-
eraseAsyncDependencyFromAsyncOp(async_user, res);
761-
for (auto dep : getAsyncDependenciesFromOp(erase_op))
762-
if (dep != getLoopCarriedTokenFromScfOp(for_op, "argument"))
763-
air::addAsyncDependencyIfNew(u, dep);
764-
} else {
765-
// User op doesn't have air::AsyncOpInterface. Replace uses with newly
766-
// generated air.wait_all op.
767-
u->replaceUsesOfWith(
768-
res, rewriter
769-
.create<air::WaitAllOp>(
770-
loc, air::AsyncTokenType::get(rewriter.getContext()),
771-
getAsyncDependenciesFromOp(erase_op))
772-
.getAsyncToken());
773-
}
760+
if (erase_op_async_deps.empty())
761+
continue;
762+
else if (erase_op_async_deps.size() == 1)
763+
res.replaceAllUsesWith(erase_op_async_deps.front());
764+
else {
765+
res.replaceAllUsesWith(
766+
rewriter
767+
.create<air::WaitAllOp>(
768+
loc, air::AsyncTokenType::get(rewriter.getContext()),
769+
erase_op_async_deps)
770+
.getAsyncToken());
774771
}
775772
}
776773
}

0 commit comments

Comments
 (0)