Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,7 @@ if(BUILD_TEST)
list(APPEND HOSTIR_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_host_ir_evaluator.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_ir_integration.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_ir_passes.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_ir_stream_lowering.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_irs.cpp
)
Expand Down
290 changes: 190 additions & 100 deletions csrc/host_ir/allocate_and_deallocate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

#include "host_ir/allocate_and_deallocate.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <list>
#include <ranges>
#include <stack>
#include <unordered_map>
#include <unordered_set>
Expand All @@ -24,91 +24,93 @@ namespace nvfuser::hir {

namespace {

class DominatorTree {
class Node {
public:
class Node {
public:
Node(Scope* scope, Scope::Iterator iterator)
: scope_(scope), iterator_(iterator) {}
Node(const Node& other) = delete;
Node(Node&& other) = delete;
Node& operator=(const Node& other) = delete;
Node& operator=(Node&& other) = delete;

const std::vector<Node*>& children() const {
return children_;
}
Node(Scope* scope, Scope::Iterator iterator, const Node* parent)
: scope_(scope), iterator_(iterator), parent_(parent) {}
Node(const Node& other) = delete;
Node(Node&& other) = delete;
Node& operator=(const Node& other) = delete;
Node& operator=(Node&& other) = delete;

const std::vector<Node*>& children() const {
return children_;
}

void addChild(Node* child) {
children_.push_back(child);
}
void addChild(Node* child) {
children_.push_back(child);
}

Scope* scope() const {
return scope_;
}
Scope* scope() const {
return scope_;
}

Scope::Iterator iterator() const {
return iterator_;
}
Scope::Iterator iterator() const {
return iterator_;
}

Expr* getExpr() const {
return *iterator_;
}
Expr* getExpr() const {
return *iterator_;
}

private:
// Consider putting `scope` and `iterator` into a separate Mutator class.
// They are only needed when the user wants to modify the host IR.
Scope* scope_;
Scope::Iterator iterator_;
const Node* parent() const {
return parent_;
}

private:
Scope* scope_;
Scope::Iterator iterator_;
const Node* parent_;
std::vector<Node*> children_;
};

std::vector<Node*> children_;
void depthFirstTraverse(
const Node* root,
const std::function<void(const Node*)>& pre_fn,
const std::function<void(const Node*)>& post_fn) {
struct Frame {
const Node* node;
bool processed;
};

explicit DominatorTree(hir::HostIrContainer& hic) : hic_(hic) {
build(hic_.topLevel(), /*parent=*/nullptr);
std::stack<Frame> stack;
stack.push({root, /*processed=*/false});
while (!stack.empty()) {
Frame& top = stack.top();
if (top.processed) {
post_fn(top.node);
stack.pop();
continue;
}

pre_fn(top.node);
top.processed = true;
for (const Node* child : top.node->children()) {
stack.push({child, /*processed=*/false});
}
}
}

class DominatorTree {
public:
explicit DominatorTree(hir::HostIrContainer& hic) : hic_(&hic) {
build(hic_->topLevel(), /*parent=*/nullptr);
}

const Node* getRoot() const {
const auto& top_level_exprs = hic_.topLevelExprs();
const auto& top_level_exprs = hic_->topLevelExprs();
NVF_ERROR(!top_level_exprs.empty());
Expr* root = top_level_exprs.front();
return &nodes_.at(root);
}

// `pre_fn` is called before traversing any child of a node. `post_fn` is
// called after traversing all children of a node.
void depthFirstTraverse(
const std::function<void(const Node*)>& pre_fn,
const std::function<void(const Node*)>& post_fn) const {
struct Frame {
const Node* node;
bool processed;
};

std::stack<Frame> stack;
stack.emplace(getRoot(), /*processed=*/false);
while (!stack.empty()) {
Frame& top = stack.top();
if (top.processed) {
post_fn(top.node);
stack.pop();
continue;
}

pre_fn(top.node);
top.processed = true;
for (const Node* child : top.node->children()) {
stack.emplace(child, /*processed=*/false);
}
}
}

private:
void build(Scope& scope, Node* parent) {
for (auto scope_it = scope.exprs().begin(); scope_it != scope.exprs().end();
++scope_it) {
Expr* e = *scope_it;
auto [node_it, inserted] = nodes_.try_emplace(e, &scope, scope_it);
auto [node_it, inserted] =
nodes_.try_emplace(e, &scope, scope_it, parent);
NVF_ERROR(inserted);
Node& node = node_it->second;
if (parent != nullptr) {
Expand All @@ -131,7 +133,49 @@ class DominatorTree {
}
}

hir::HostIrContainer& hic_;
hir::HostIrContainer* hic_;
std::unordered_map<const Expr*, Node> nodes_;
};

class PostDominatorTree {
public:
explicit PostDominatorTree(hir::HostIrContainer& hic) : hic_(&hic) {
build(hic_->topLevel(), /*parent=*/nullptr);
}

const Node* getRoot() const {
const auto& top_level_exprs = hic_->topLevelExprs();
NVF_ERROR(!top_level_exprs.empty());
Expr* root = top_level_exprs.back();
return &nodes_.at(root);
}

private:
void build(Scope& scope, Node* parent) {
auto& exprs = scope.exprs();
for (auto it = exprs.end(); it != exprs.begin();) {
--it;
Expr* e = *it;
auto [node_it, inserted] = nodes_.try_emplace(e, &scope, it, parent);
NVF_ERROR(inserted);
Node& node = node_it->second;
if (parent != nullptr) {
parent->addChild(&node);
}

if (auto* loop = dynamic_cast<hir::ForLoop*>(e)) {
build(loop->body(), &node);
}
if (auto* ite = dynamic_cast<kir::IfThenElse*>(e)) {
build(ite->thenBody(), &node);
build(ite->elseBody(), &node);
}

parent = &node;
}
}

hir::HostIrContainer* hic_;
std::unordered_map<const Expr*, Node> nodes_;
};

Expand All @@ -157,9 +201,10 @@ void insertAllocations(hir::HostIrContainer& hic) {
DominatorTree dom_tree(hic);
std::unordered_set<TensorView*> defined;

dom_tree.depthFirstTraverse(
depthFirstTraverse(
/*root=*/dom_tree.getRoot(),
/*pre_fn=*/
[&](const DominatorTree::Node* node) {
[&](const Node* node) {
Expr* e = node->getExpr();
// If `e`'s output needs preallocation but isn't defined, insert an
// allocation right before `e`.
Expand All @@ -178,57 +223,102 @@ void insertAllocations(hir::HostIrContainer& hic) {
}
},
/*post_fn=*/
[&](const DominatorTree::Node* node) {
[&](const Node* node) {
Expr* e = node->getExpr();
for (auto* out : ir_utils::filterByType<TensorView>(e->outputs())) {
defined.erase(out);
}
});
}

// For each TensorView that is allocated, read, or written by an expression,
// find the lowest common ancestor of those expressions in the post-dominator
// tree. The deallocation is inserted right after that node.
std::unordered_map<TensorView*, const Node*> computeLeastCommonAncestor(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::unordered_map<TensorView*, const Node*> computeLeastCommonAncestor(
std::unordered_map<TensorView*, const Node*> computeLowestCommonAncestor(

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad.

const PostDominatorTree& pdt) {
std::unordered_map<const Node*, int64_t> depth;

auto findLCA = [&](const Node* a, const Node* b) -> const Node* {
if (a == nullptr) {
return b;
}
if (b == nullptr) {
return a;
}
int64_t depth_a = depth.at(a);
int64_t depth_b = depth.at(b);
while (depth_a > depth_b) {
a = a->parent();
depth_a--;
}
while (depth_b > depth_a) {
b = b->parent();
depth_b--;
}
while (a != b) {
a = a->parent();
b = b->parent();
}
return a;
};
Comment on lines +239 to +263
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider making it a private method of class LowestCommonAncestors, and making depth LowestCommonAncestors::depth_.


std::unordered_map<TensorView*, const Node*> lca;
int64_t current_depth = -1;

depthFirstTraverse(
/*root=*/pdt.getRoot(),
/*pre_fn=*/
[&](const Node* node) {
current_depth++;
depth[node] = current_depth;
Expr* e = node->getExpr();

// Temporary special case for kir::Allocate. We will switch to
// inserting a new `hir::Allocate` in host IR lowering, where
// the allocated `tv` will be the expr input.
if (auto* alloc = dynamic_cast<kir::Allocate*>(e)) {
auto* tv = alloc->buffer()->as<TensorView>();
lca[tv] = findLCA(lca[tv], node);
}
for (auto* tv : ir_utils::filterByType<TensorView>(e->inputs())) {
lca[tv] = findLCA(lca[tv], node);
}
for (auto* tv : ir_utils::filterByType<TensorView>(e->outputs())) {
lca[tv] = findLCA(lca[tv], node);
}
},
/*post_fn=*/
[&](const Node*) { --current_depth; });

return lca;
}

void insertDeallocations(hir::HostIrContainer& hic) {
const std::list<Expr*>& top_level_exprs = hic.topLevelExprs();
std::for_each(top_level_exprs.begin(), top_level_exprs.end(), [](Expr* expr) {
std::ranges::for_each(top_level_exprs, [](Expr* expr) {
NVF_ERROR(
!expr->isA<hir::Deallocate>(),
"Expected hostir container to not have deallocate, but found one "
"anyways: ",
expr);
});

// For each input in every expression in the container, find the position of
// its last use and insert a deallocate directly after, except for fusion
// inputs and outputs.
std::unordered_set<TensorView*> last_use_found;
for (auto insertion_point = top_level_exprs.end();
insertion_point != top_level_exprs.begin();) {
auto prev = std::prev(insertion_point);
Expr* e = *prev;

// Only tensors need to be allocated.
for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) {
// Fusion inputs are managed by the caller.
if (in->isFusionInput()) {
continue;
}

// Fusion outputs need to be kept alive for the caller.
if (in->isFusionOutput()) {
continue;
}

// Skip if `e` is not the last use.
if (!last_use_found.insert(in).second) {
continue;
}
PostDominatorTree pdt(hic);
const std::unordered_map<TensorView*, const Node*>& lca_map =
computeLeastCommonAncestor(pdt);

auto* deallocate = IrBuilder::create<hir::Deallocate>(in);
hic.insertExprBefore(insertion_point, deallocate);
// Insert deallocate at LCA for each tensorview that is not a fusion input or
// output.
for (const auto& [tv, lca_node] : lca_map) {
if (tv->isFusionInput() || tv->isFusionOutput()) {
continue;
}

// Don't `--insertion_point;` because we'd like to skip newly inserted
// deallocations.
insertion_point = prev;
NVF_ERROR(
lca_node != nullptr,
"Could not find least common ancestor for all uses of ",
tv);
auto* deallocate = IrBuilder::create<hir::Deallocate>(tv);
lca_node->scope()->insert(std::next(lca_node->iterator()), deallocate);
}
}

Expand Down
Loading