
Commit 616afcf

navahgar authored and facebook-github-bot committed
[jit] [shape analysis] Move constant tensors out of fused subgraphs during generalization (pytorch#70320)

Summary: Pull Request resolved: pytorch#70320
ghstack-source-id: 146514368
Test Plan: `buck test mode/dev-nosan //caffe2/test/cpp/jit:jit`
Reviewed By: eellison
Differential Revision: D33280508
fbshipit-source-id: fe4291d7c49f0a498b330de96b698e99f6f6a505
1 parent b60b1b1 commit 616afcf
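
At a high level, the change clones tensor-valued prim::Constant nodes out of a fused TensorExprGroup subgraph and passes them back in as ordinary subgraph inputs, so that symbolic shape generalization never has to reason about a broadcast between a symbolic shape and a fixed-shape constant. A minimal before/after sketch in TorchScript IR (illustrative only; the value names and shapes are invented, not taken from this commit):

Before, with the constant tensor inside the fused subgraph:

    graph(%x : Tensor):
      %out : Tensor = prim::TensorExprGroup_0(%x)
      return (%out)
    with prim::TensorExprGroup_0 = graph(%x : Tensor):
      %ones : Tensor = prim::Constant[value=<Tensor>]()
      %t : Tensor = aten::tanh(%x)
      %out : Tensor = aten::mul(%t, %ones)
      return (%out)

After, with the constant cloned into the outer graph and threaded in as an extra input:

    graph(%x : Tensor):
      %ones : Tensor = prim::Constant[value=<Tensor>]()
      %out : Tensor = prim::TensorExprGroup_0(%x, %ones)
      return (%out)
    with prim::TensorExprGroup_0 = graph(%x : Tensor, %ones : Tensor):
      %t : Tensor = aten::tanh(%x)
      %out : Tensor = aten::mul(%t, %ones)
      return (%out)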

File tree

2 files changed: +91, -0 lines changed


test/cpp/jit/test_shape_analysis.cpp

Lines changed: 42 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/ir/irparser.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/runtime/graph_iterator.h>
@@ -246,5 +247,46 @@ TEST(ShapeAnalysisTest, DynamicShapesFusion) {
   }
 }
 
+TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) {
+  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
+  const auto graph_string = R"IR(
+      graph(%x.1 : Tensor):
+        %none : NoneType = prim::Constant()
+        %size1 : int = prim::Constant[value=1]()
+        %size10 : int = prim::Constant[value=10]()
+        %sizes : int[] = prim::ListConstruct(%size10, %size1)
+        %device : Device = prim::Constant[value="cpu"]()
+        %10 : Tensor = aten::ones(%sizes, %none, %none, %device, %none)
+        %3 : Tensor = aten::tanh(%x.1)
+        %29 : Tensor = aten::mul(%3, %10)
+        return (%29))IR";
+  torch::jit::parseIR(graph_string, subgraph.get());
+  ConstantPropagation(subgraph);
+
+  std::shared_ptr<Graph> g = std::make_shared<Graph>();
+  auto x_inp = g->addInput("x_inp");
+  auto x_type = TensorType::create(at::rand({10, 5}));
+  x_inp->setType(x_type);
+  subgraph->inputs().at(0)->setType(x_type);
+  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
+  output->node()->addInput(x_inp);
+  output->node()->g_(attr::Subgraph, subgraph);
+
+  auto success = GenerateGuard(output->node());
+  TORCH_INTERNAL_ASSERT(success);
+
+  // Check that the constants have been moved out of the fused graph.
+  // This should result in not having any conditionals other than the one
+  // checking the result of TensorExprDynamicGuard.
+  testing::FileCheck()
+      .check("TensorExprDynamicGuard")
+      ->check_next("prim::If")
+      ->check_not("prim::If") // no other Ifs due to the constants
+      ->check("TensorExprGroup")
+      ->check("block1")
+      ->check("FallbackGraph")
+      ->run(*g);
+}
+
 } // namespace jit
 } // namespace torch
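
For reference, the guarded graph that the FileCheck directives above are matching should look roughly as follows after GenerateGuard (a hand-written approximation, not actual pass output; the value names are invented):

    graph(%x_inp : Tensor):
      %ones : Tensor = prim::Constant[value=<Tensor>]()
      %ok : bool = prim::TensorExprDynamicGuard[types=[...]](%x_inp)
      %y : Tensor = prim::If(%ok)
        block0():
          %fused : Tensor = prim::TensorExprGroup_0(%x_inp, %ones)
          -> (%fused)
        block1():
          %fb : Tensor = prim::FallbackGraph_0(%x_inp, %ones)
          -> (%fb)
      return (%y)

The check_not("prim::If") directive pins down the interesting property: because the aten::ones result was hoisted out of the subgraph, the only conditional left is the one testing the TensorExprDynamicGuard result.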

torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp

Lines changed: 49 additions & 0 deletions
@@ -106,9 +106,58 @@ bool TryGeneralizeInputDimensionsToSymbolicShapes(
   return true;
 }
 
+void moveConstantTensorsOutOfSubgraph(
+    Node* tensorexpr_graph_node,
+    std::shared_ptr<Graph> tensorexpr_graph) {
+  auto parent = tensorexpr_graph_node->owningGraph();
+
+  auto env = [&](Value* v) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "this should never happen since constant nodes do not have any inputs",
+        v->debugName());
+    return v;
+  };
+
+  WithInsertPoint wip(tensorexpr_graph_node);
+  std::vector<Node*> to_destroy;
+  for (auto node : tensorexpr_graph->nodes()) {
+    if (node->kind() == prim::Constant) {
+      if (!node->output()->type()->cast<TensorType>()) {
+        continue;
+      }
+
+      // Copy the constant and insert that copy into the parent graph.
+      auto copy = parent->createClone(node, env);
+      parent->insertNode(copy);
+
+      // Add a new input to the TE subgraph and replace the uses of the
+      // constant with this input.
+      auto new_const = tensorexpr_graph->addInput();
+      new_const->setType(node->output()->type());
+      node->output()->replaceAllUsesWith(new_const);
+
+      // Add the copy as an input to the TE node.
+      tensorexpr_graph_node->addInput(copy->output());
+
+      to_destroy.push_back(node);
+    }
+  }
+
+  for (auto n : to_destroy) {
+    n->destroy();
+  }
+}
+
 bool GenerateGuard(Node* tensorexpr_graph_node, bool add_composed_op) {
   auto tensorexpr_graph = SubgraphUtils::getSubgraph(tensorexpr_graph_node);
 
+  // Move constant tensors from the subgraph to the outer scope.
+  // This is necessary because symbolic shape analysis does not handle the
+  // case of broadcast(constant, symbolic_shape) well, and that results in
+  // poor performance.
+  moveConstantTensorsOutOfSubgraph(tensorexpr_graph_node, tensorexpr_graph);
+
   // Generalize Inputs
   if (!TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph)) {
     return false;
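
To see why the hoisting helps, consider what the fused subgraph would contain if the constant stayed inside while the other inputs were generalized to symbolic shapes (a sketch; the SS(...) symbolic-dimension notation and the concrete shapes are illustrative assumptions):

    %t : Tensor(SS(-2), SS(-3)) = aten::tanh(%x)
    %out : Tensor = aten::mul(%t, %ones)

Here %ones would keep a concrete shape such as [10, 1], so the aten::mul is exactly the broadcast(constant, symbolic_shape) case the comment above describes as poorly handled. Once moveConstantTensorsOutOfSubgraph runs, %ones is an ordinary input to the TensorExprGroup node, and its dimensions participate in input generalization and guard generation like any other input.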
