-
Notifications
You must be signed in to change notification settings - Fork 78
insertDeallocate inspects inner scopes
#6007
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f2ac6bb
678f6ba
fd49660
a92f4db
e496482
8219fe4
0cd2f48
aec9e85
a338a1c
80a701c
d66d161
dfd195c
a16de22
6bf14a6
d897e7c
08a5999
52f73dc
c72b706
b72eefb
2078dbc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -8,10 +8,10 @@ | |||||
|
|
||||||
| #include "host_ir/allocate_and_deallocate.h" | ||||||
|
|
||||||
| #include <algorithm> | ||||||
| #include <functional> | ||||||
| #include <iterator> | ||||||
| #include <list> | ||||||
| #include <ranges> | ||||||
| #include <stack> | ||||||
| #include <unordered_map> | ||||||
| #include <unordered_set> | ||||||
|
|
@@ -24,91 +24,93 @@ namespace nvfuser::hir { | |||||
|
|
||||||
| namespace { | ||||||
|
|
||||||
| class DominatorTree { | ||||||
| class Node { | ||||||
| public: | ||||||
| class Node { | ||||||
| public: | ||||||
| Node(Scope* scope, Scope::Iterator iterator) | ||||||
| : scope_(scope), iterator_(iterator) {} | ||||||
| Node(const Node& other) = delete; | ||||||
| Node(Node&& other) = delete; | ||||||
| Node& operator=(const Node& other) = delete; | ||||||
| Node& operator=(Node&& other) = delete; | ||||||
|
|
||||||
| const std::vector<Node*>& children() const { | ||||||
| return children_; | ||||||
| } | ||||||
| Node(Scope* scope, Scope::Iterator iterator, const Node* parent) | ||||||
| : scope_(scope), iterator_(iterator), parent_(parent) {} | ||||||
| Node(const Node& other) = delete; | ||||||
| Node(Node&& other) = delete; | ||||||
| Node& operator=(const Node& other) = delete; | ||||||
| Node& operator=(Node&& other) = delete; | ||||||
|
|
||||||
| const std::vector<Node*>& children() const { | ||||||
| return children_; | ||||||
| } | ||||||
|
|
||||||
| void addChild(Node* child) { | ||||||
| children_.push_back(child); | ||||||
| } | ||||||
| void addChild(Node* child) { | ||||||
| children_.push_back(child); | ||||||
| } | ||||||
|
|
||||||
| Scope* scope() const { | ||||||
| return scope_; | ||||||
| } | ||||||
| Scope* scope() const { | ||||||
| return scope_; | ||||||
| } | ||||||
|
|
||||||
| Scope::Iterator iterator() const { | ||||||
| return iterator_; | ||||||
| } | ||||||
| Scope::Iterator iterator() const { | ||||||
| return iterator_; | ||||||
| } | ||||||
|
|
||||||
| Expr* getExpr() const { | ||||||
| return *iterator_; | ||||||
| } | ||||||
| Expr* getExpr() const { | ||||||
| return *iterator_; | ||||||
| } | ||||||
|
|
||||||
| private: | ||||||
| // Consider putting `scope` and `iterator` into a separate Mutator class. | ||||||
| // They are only needed when the user wants to modify the host IR. | ||||||
| Scope* scope_; | ||||||
| Scope::Iterator iterator_; | ||||||
| const Node* parent() const { | ||||||
| return parent_; | ||||||
| } | ||||||
|
|
||||||
| private: | ||||||
| Scope* scope_; | ||||||
| Scope::Iterator iterator_; | ||||||
| const Node* parent_; | ||||||
| std::vector<Node*> children_; | ||||||
| }; | ||||||
|
|
||||||
| std::vector<Node*> children_; | ||||||
| void depthFirstTraverse( | ||||||
| const Node* root, | ||||||
| const std::function<void(const Node*)>& pre_fn, | ||||||
| const std::function<void(const Node*)>& post_fn) { | ||||||
| struct Frame { | ||||||
| const Node* node; | ||||||
| bool processed; | ||||||
| }; | ||||||
|
|
||||||
| explicit DominatorTree(hir::HostIrContainer& hic) : hic_(hic) { | ||||||
| build(hic_.topLevel(), /*parent=*/nullptr); | ||||||
| std::stack<Frame> stack; | ||||||
| stack.push({root, /*processed=*/false}); | ||||||
| while (!stack.empty()) { | ||||||
| Frame& top = stack.top(); | ||||||
| if (top.processed) { | ||||||
| post_fn(top.node); | ||||||
| stack.pop(); | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| pre_fn(top.node); | ||||||
| top.processed = true; | ||||||
| for (const Node* child : top.node->children()) { | ||||||
| stack.push({child, /*processed=*/false}); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| class DominatorTree { | ||||||
| public: | ||||||
| explicit DominatorTree(hir::HostIrContainer& hic) : hic_(&hic) { | ||||||
| build(hic_->topLevel(), /*parent=*/nullptr); | ||||||
| } | ||||||
|
|
||||||
| const Node* getRoot() const { | ||||||
| const auto& top_level_exprs = hic_.topLevelExprs(); | ||||||
| const auto& top_level_exprs = hic_->topLevelExprs(); | ||||||
| NVF_ERROR(!top_level_exprs.empty()); | ||||||
| Expr* root = top_level_exprs.front(); | ||||||
| return &nodes_.at(root); | ||||||
| } | ||||||
|
|
||||||
| // `pre_fn` is called before traversing any child of a node. `post_fn` is | ||||||
| // called after traversing all children of a node. | ||||||
| void depthFirstTraverse( | ||||||
| const std::function<void(const Node*)>& pre_fn, | ||||||
| const std::function<void(const Node*)>& post_fn) const { | ||||||
| struct Frame { | ||||||
| const Node* node; | ||||||
| bool processed; | ||||||
| }; | ||||||
|
|
||||||
| std::stack<Frame> stack; | ||||||
| stack.emplace(getRoot(), /*processed=*/false); | ||||||
| while (!stack.empty()) { | ||||||
| Frame& top = stack.top(); | ||||||
| if (top.processed) { | ||||||
| post_fn(top.node); | ||||||
| stack.pop(); | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| pre_fn(top.node); | ||||||
| top.processed = true; | ||||||
| for (const Node* child : top.node->children()) { | ||||||
| stack.emplace(child, /*processed=*/false); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| private: | ||||||
| void build(Scope& scope, Node* parent) { | ||||||
| for (auto scope_it = scope.exprs().begin(); scope_it != scope.exprs().end(); | ||||||
| ++scope_it) { | ||||||
| Expr* e = *scope_it; | ||||||
| auto [node_it, inserted] = nodes_.try_emplace(e, &scope, scope_it); | ||||||
| auto [node_it, inserted] = | ||||||
| nodes_.try_emplace(e, &scope, scope_it, parent); | ||||||
| NVF_ERROR(inserted); | ||||||
| Node& node = node_it->second; | ||||||
| if (parent != nullptr) { | ||||||
|
|
@@ -131,7 +133,49 @@ class DominatorTree { | |||||
| } | ||||||
| } | ||||||
|
|
||||||
| hir::HostIrContainer& hic_; | ||||||
| hir::HostIrContainer* hic_; | ||||||
| std::unordered_map<const Expr*, Node> nodes_; | ||||||
| }; | ||||||
|
|
||||||
| class PostDominatorTree { | ||||||
| public: | ||||||
| explicit PostDominatorTree(hir::HostIrContainer& hic) : hic_(&hic) { | ||||||
| build(hic_->topLevel(), /*parent=*/nullptr); | ||||||
| } | ||||||
|
|
||||||
| const Node* getRoot() const { | ||||||
| const auto& top_level_exprs = hic_->topLevelExprs(); | ||||||
| NVF_ERROR(!top_level_exprs.empty()); | ||||||
| Expr* root = top_level_exprs.back(); | ||||||
| return &nodes_.at(root); | ||||||
| } | ||||||
|
|
||||||
| private: | ||||||
| void build(Scope& scope, Node* parent) { | ||||||
| auto& exprs = scope.exprs(); | ||||||
| for (auto it = exprs.end(); it != exprs.begin();) { | ||||||
| --it; | ||||||
| Expr* e = *it; | ||||||
| auto [node_it, inserted] = nodes_.try_emplace(e, &scope, it, parent); | ||||||
| NVF_ERROR(inserted); | ||||||
| Node& node = node_it->second; | ||||||
| if (parent != nullptr) { | ||||||
| parent->addChild(&node); | ||||||
| } | ||||||
|
|
||||||
| if (auto* loop = dynamic_cast<hir::ForLoop*>(e)) { | ||||||
| build(loop->body(), &node); | ||||||
| } | ||||||
| if (auto* ite = dynamic_cast<kir::IfThenElse*>(e)) { | ||||||
| build(ite->thenBody(), &node); | ||||||
| build(ite->elseBody(), &node); | ||||||
| } | ||||||
|
|
||||||
| parent = &node; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| hir::HostIrContainer* hic_; | ||||||
| std::unordered_map<const Expr*, Node> nodes_; | ||||||
| }; | ||||||
|
|
||||||
|
|
@@ -157,9 +201,10 @@ void insertAllocations(hir::HostIrContainer& hic) { | |||||
| DominatorTree dom_tree(hic); | ||||||
| std::unordered_set<TensorView*> defined; | ||||||
|
|
||||||
| dom_tree.depthFirstTraverse( | ||||||
| depthFirstTraverse( | ||||||
| /*root=*/dom_tree.getRoot(), | ||||||
| /*pre_fn=*/ | ||||||
| [&](const DominatorTree::Node* node) { | ||||||
| [&](const Node* node) { | ||||||
| Expr* e = node->getExpr(); | ||||||
| // If `e`'s output needs preallocation but isn't defined, insert an | ||||||
| // allocation right before `e`. | ||||||
|
|
@@ -178,57 +223,102 @@ void insertAllocations(hir::HostIrContainer& hic) { | |||||
| } | ||||||
| }, | ||||||
| /*post_fn=*/ | ||||||
| [&](const DominatorTree::Node* node) { | ||||||
| [&](const Node* node) { | ||||||
| Expr* e = node->getExpr(); | ||||||
| for (auto* out : ir_utils::filterByType<TensorView>(e->outputs())) { | ||||||
| defined.erase(out); | ||||||
| } | ||||||
| }); | ||||||
| } | ||||||
|
|
||||||
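| // For each TensorView that is allocated, read, or written by an expression, | ||||||
| // find the lowest common ancestor of all those expressions in the | ||||||
| // post-dominator tree. That node post-dominates every use of the TensorView, | ||||||
| // so it is the earliest point after which it can safely be deallocated. | ||||||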
| std::unordered_map<TensorView*, const Node*> computeLeastCommonAncestor( | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My bad. |
||||||
| const PostDominatorTree& pdt) { | ||||||
| std::unordered_map<const Node*, int64_t> depth; | ||||||
|
|
||||||
| auto findLCA = [&](const Node* a, const Node* b) -> const Node* { | ||||||
| if (a == nullptr) { | ||||||
| return b; | ||||||
| } | ||||||
| if (b == nullptr) { | ||||||
| return a; | ||||||
| } | ||||||
| int64_t depth_a = depth.at(a); | ||||||
| int64_t depth_b = depth.at(b); | ||||||
| while (depth_a > depth_b) { | ||||||
| a = a->parent(); | ||||||
| depth_a--; | ||||||
| } | ||||||
| while (depth_b > depth_a) { | ||||||
| b = b->parent(); | ||||||
| depth_b--; | ||||||
| } | ||||||
| while (a != b) { | ||||||
| a = a->parent(); | ||||||
| b = b->parent(); | ||||||
| } | ||||||
| return a; | ||||||
| }; | ||||||
|
Comment on lines
+239
to
+263
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider making it a private method of class LowestCommonAncestors, and making depth LowestCommonAncestors::depth_. |
||||||
|
|
||||||
| std::unordered_map<TensorView*, const Node*> lca; | ||||||
| int64_t current_depth = -1; | ||||||
|
|
||||||
| depthFirstTraverse( | ||||||
| /*root=*/pdt.getRoot(), | ||||||
| /*pre_fn=*/ | ||||||
| [&](const Node* node) { | ||||||
| current_depth++; | ||||||
| depth[node] = current_depth; | ||||||
| Expr* e = node->getExpr(); | ||||||
|
|
||||||
| // Temporary special case for kir::Allocate. We will switch to | ||||||
| // inserting a new `hir::Allocate` in host IR lowering, where | ||||||
| // the allocated `tv` will be the expr's input. | ||||||
| if (auto* alloc = dynamic_cast<kir::Allocate*>(e)) { | ||||||
| auto* tv = alloc->buffer()->as<TensorView>(); | ||||||
| lca[tv] = findLCA(lca[tv], node); | ||||||
| } | ||||||
| for (auto* tv : ir_utils::filterByType<TensorView>(e->inputs())) { | ||||||
| lca[tv] = findLCA(lca[tv], node); | ||||||
| } | ||||||
| for (auto* tv : ir_utils::filterByType<TensorView>(e->outputs())) { | ||||||
| lca[tv] = findLCA(lca[tv], node); | ||||||
| } | ||||||
| }, | ||||||
| /*post_fn=*/ | ||||||
| [&](const Node*) { --current_depth; }); | ||||||
|
|
||||||
| return lca; | ||||||
| } | ||||||
|
|
||||||
| void insertDeallocations(hir::HostIrContainer& hic) { | ||||||
| const std::list<Expr*>& top_level_exprs = hic.topLevelExprs(); | ||||||
| std::for_each(top_level_exprs.begin(), top_level_exprs.end(), [](Expr* expr) { | ||||||
| std::ranges::for_each(top_level_exprs, [](Expr* expr) { | ||||||
| NVF_ERROR( | ||||||
| !expr->isA<hir::Deallocate>(), | ||||||
| "Expected hostir container to not have deallocate, but found one " | ||||||
| "anyways: ", | ||||||
| expr); | ||||||
| }); | ||||||
|
|
||||||
| // For each input in every expression in the container, find the position of | ||||||
| // its last use and insert a deallocate directly after, except for fusion | ||||||
| // inputs and outputs. | ||||||
| std::unordered_set<TensorView*> last_use_found; | ||||||
| for (auto insertion_point = top_level_exprs.end(); | ||||||
| insertion_point != top_level_exprs.begin();) { | ||||||
| auto prev = std::prev(insertion_point); | ||||||
| Expr* e = *prev; | ||||||
|
|
||||||
| // Only tensors need to be allocated. | ||||||
| for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) { | ||||||
| // Fusion inputs are managed by the caller. | ||||||
| if (in->isFusionInput()) { | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| // Fusion outputs need to be kept alive for the caller. | ||||||
| if (in->isFusionOutput()) { | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| // Skip if `e` is not the last use. | ||||||
| if (!last_use_found.insert(in).second) { | ||||||
| continue; | ||||||
| } | ||||||
| PostDominatorTree pdt(hic); | ||||||
| const std::unordered_map<TensorView*, const Node*>& lca_map = | ||||||
| computeLeastCommonAncestor(pdt); | ||||||
|
|
||||||
| auto* deallocate = IrBuilder::create<hir::Deallocate>(in); | ||||||
| hic.insertExprBefore(insertion_point, deallocate); | ||||||
| // Insert a deallocate right after the LCA node of each TensorView that is | ||||||
| // neither a fusion input nor a fusion output. | ||||||
| for (const auto& [tv, lca_node] : lca_map) { | ||||||
| if (tv->isFusionInput() || tv->isFusionOutput()) { | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| // Don't `--insertion_point;` because we'd like to skip newly inserted | ||||||
| // deallocations. | ||||||
| insertion_point = prev; | ||||||
| NVF_ERROR( | ||||||
| lca_node != nullptr, | ||||||
| "Could not find least common ancestor for all uses of ", | ||||||
| tv); | ||||||
| auto* deallocate = IrBuilder::create<hir::Deallocate>(tv); | ||||||
| lca_node->scope()->insert(std::next(lca_node->iterator()), deallocate); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.