Commit

[StaticRuntime] Fix a bug that memory planner ignores subblocks (pytorch#146728) (pytorch#146855)

Summary:

When a Static Runtime graph node has sub-blocks, the memory planner does not treat the sub-blocks' inputs as inputs of the node. As a result, the lifetimes of such inputs are computed as ending too early, and the corresponding tensor memory is released before its last use, causing errors.
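
To illustrate (this mirrors the src_plain case in the test added below): in the graph

    graph(%cond : bool, %a : Tensor):
      %b : Tensor = aten::mul(%a, %a)
      %output : bool = prim::If(%cond)
        block0():
          -> (%a)
        block1():
          %c : Tensor = aten::mul(%b, %a)
          -> (%c)
      return (%output)

the managed tensor %b is consumed only inside a sub-block of the prim::If. Before this fix, the planner scanned only the top-level node->inputs(), so %b's lifetime ended at the aten::mul that produced it, and its buffer could be reclaimed before block1 ran.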

Differential Revision: D69195886

Pull Request resolved: pytorch#146855
Approved by: https://github.com/swolchok
coufon authored and pytorchmergebot committed Feb 11, 2025
1 parent 15635b1 commit fc5913b
Showing 3 changed files with 75 additions and 8 deletions.
53 changes: 53 additions & 0 deletions benchmarks/static_runtime/test_static_module.cc
@@ -1274,6 +1274,59 @@ TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) {
  EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
}

TEST(ManagedTensorRanges, LifetimeIncludeSubBlockInputs) {
  const std::string src_plain = R"IR(
    graph(%cond : bool, %a : Tensor):
      %b : Tensor = aten::mul(%a, %a)
      %output : bool = prim::If(%cond)
        block0():
          -> (%a)
        block1():
          %c : Tensor = aten::mul(%b, %a)
          -> (%c)
      return (%output)
  )IR";
  const std::string src_recursive = R"IR(
    graph(%cond : bool, %a : Tensor):
      %b : Tensor = aten::mul(%a, %a)
      %output : bool = prim::If(%cond)
        block0():
          -> (%a)
        block1():
          %outputblock1 : bool = prim::If(%cond)
            block0():
              -> (%a)
            block1():
              %c : Tensor = aten::mul(%b, %a)
              -> (%c)
          -> (%outputblock1)
      return (%output)
  )IR";

  for (const auto& src : {src_plain, src_recursive}) {
    auto graph = std::make_shared<Graph>();
    std::unordered_map<std::string, Value*> vmap;
    parseIR(src, graph.get(), vmap);

    auto* b = vmap["b"];

    FastSet<const Value*> managed_tensors = {b};
    AliasDb alias_db(graph);
    auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);

    std::vector<Node*> nodes(
        graph->block()->nodes().begin(), graph->block()->nodes().end());
    ASSERT_EQ(nodes.size(), 2);

    EXPECT_FALSE(ranges.nodeFreesManagedTensors(nodes[0]));

    EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[1]));
    EXPECT_EQ(
        ranges.availableTensorValuesAfterNode(nodes[1]),
        std::vector<const Value*>{b});
  }
}
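
In both graphs the top-level block contains exactly two nodes, aten::mul (node 0) and prim::If (node 1). With the fix, %b's lifetime extends to the prim::If even though %b is consumed only inside its sub-blocks, so the assertions expect node 1, not node 0, to free %b and to report it via availableTensorValuesAfterNode.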

namespace {

// For checking the correctness of assignStorageToManageTensors, the following
28 changes: 20 additions & 8 deletions torch/csrc/jit/runtime/static/impl.cpp
@@ -395,6 +395,25 @@ bool isPureFunction(const Node* node) {

} // namespace

void ManagedTensorRanges::extendLifetime(Value* input, size_t new_end) {
  auto* lifetime = getLifetime(input);
  if (lifetime) {
    TORCH_DCHECK_LE(lifetime->end, new_end);
    lifetime->end = new_end;
  }
}

void ManagedTensorRanges::extendInputLifetime(Node* node, size_t new_end) {
  for (auto* input : node->inputs()) {
    extendLifetime(input, new_end);
  }
  for (auto* subblock : node->blocks()) {
    for (auto* subnode : subblock->nodes()) {
      extendInputLifetime(subnode, new_end);
    }
  }
}

ManagedTensorRanges::ManagedTensorRanges(
    Block& block,
    const AliasDb& alias_db,
@@ -404,14 +423,7 @@ ManagedTensorRanges::ManagedTensorRanges(
   const auto num_nodes = static_cast<uint32_t>(nodes.size());
   for (const auto i : c10::irange(num_nodes)) {
     auto* node = nodes[i];
-    for (auto* input : node->inputs()) {
-      auto* lifetime = getLifetime(input);
-      if (!lifetime) {
-        continue;
-      }
-      DCHECK(lifetime->end <= i);
-      lifetime->end = i;
-    }
+    extendInputLifetime(node, i);
     for (auto* output : node->outputs()) {
       if (!alias_db.isMutableType(output)) {
         continue;
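
The recursion in extendInputLifetime is what makes the src_recursive test case pass: visiting the outer prim::If at top-level index 1 descends through the nested blocks until it reaches the aten::mul(%b, %a) sub-node. A sketch of that traversal (node indices refer to the top-level block; extendLifetime is a no-op for values without a tracked lifetime):

    extendInputLifetime(outer prim::If, 1)      // top-level node 1
      extendLifetime(%cond, 1)                  // no tracked lifetime: no-op
      extendInputLifetime(inner prim::If, 1)    // sub-node of block1
        extendLifetime(%cond, 1)                // no-op again
        extendInputLifetime(aten::mul, 1)       // sub-node of inner block1
          extendLifetime(%b, 1)                 // %b's lifetime end: 0 -> 1
          extendLifetime(%a, 1)                 // extends %a too, if tracked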
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/static/impl.h
@@ -137,6 +137,8 @@ class TORCH_API ManagedTensorRanges {
  // type are mutable)
  std::vector<const Value*> collectValuesWithTrackedLifetimes(
      at::ArrayRef<const Value*> values);
  void extendLifetime(Value* input, size_t new_end);
  void extendInputLifetime(Node* node, size_t new_end);

  // Maps Node* to the set of managed tensors that are now available
  // for re-use after this node.
