@@ -752,25 +752,22 @@ scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter,
752
752
for (auto erase_op : target_ops) {
753
753
// Reconnect returned tokens.
754
754
rewriter.setInsertionPoint (erase_op);
755
+ SmallVector<Value> erase_op_async_deps =
756
+ air::getAsyncDependenciesFromOp (erase_op);
755
757
for (auto res : erase_op->getResults ()) {
756
758
if (!isa<air::AsyncTokenType>(res.getType ()))
757
759
continue ;
758
- for (auto u : res.getUsers ()) {
759
- if (auto async_user = dyn_cast<air::AsyncOpInterface>(u)) {
760
- eraseAsyncDependencyFromAsyncOp (async_user, res);
761
- for (auto dep : getAsyncDependenciesFromOp (erase_op))
762
- if (dep != getLoopCarriedTokenFromScfOp (for_op, " argument" ))
763
- air::addAsyncDependencyIfNew (u, dep);
764
- } else {
765
- // User op doesn't have air::AsyncOpInterface. Replace uses with newly
766
- // generated air.wait_all op.
767
- u->replaceUsesOfWith (
768
- res, rewriter
769
- .create <air::WaitAllOp>(
770
- loc, air::AsyncTokenType::get (rewriter.getContext ()),
771
- getAsyncDependenciesFromOp (erase_op))
772
- .getAsyncToken ());
773
- }
760
+ if (erase_op_async_deps.empty ())
761
+ continue ;
762
+ else if (erase_op_async_deps.size () == 1 )
763
+ res.replaceAllUsesWith (erase_op_async_deps.front ());
764
+ else {
765
+ res.replaceAllUsesWith (
766
+ rewriter
767
+ .create <air::WaitAllOp>(
768
+ loc, air::AsyncTokenType::get (rewriter.getContext ()),
769
+ erase_op_async_deps)
770
+ .getAsyncToken ());
774
771
}
775
772
}
776
773
}
0 commit comments