Skip to content

Commit b2b6871

Browse files
authored
fix: Replace RemoveDropout lowering pass implementation with modified JIT pass (#1589)
1 parent 675bdfd commit b2b6871

File tree

2 files changed

+98
-80
lines changed

2 files changed

+98
-80
lines changed

core/lowering/passes/remove_dropout.cpp

Lines changed: 46 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
1+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
22

33
#include "core/util/prelude.h"
44

@@ -7,86 +7,52 @@ namespace core {
77
namespace lowering {
88
namespace passes {
99

10-
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
11-
std::string dropout_pattern = R"IR(
12-
graph(%input, %4, %5):
13-
%6 = aten::dropout(%input, %4, %5)
14-
return (%6))IR";
15-
std::string no_dropout_pattern = R"IR(
16-
graph(%input, %4, %5):
17-
return (%input))IR";
18-
19-
torch::jit::SubgraphRewriter remove_dropout;
20-
remove_dropout.RegisterRewritePattern(dropout_pattern, no_dropout_pattern);
21-
remove_dropout.runOnGraph(graph);
22-
23-
std::string dropout_inplace_pattern = R"IR(
24-
graph(%input, %4, %5):
25-
%6 = aten::dropout_(%input, %4, %5)
26-
return (%6))IR";
27-
std::string no_dropout_inplace_pattern = R"IR(
28-
graph(%input, %4, %5):
29-
return (%input))IR";
30-
31-
torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
32-
remove_dropout_inplace_pattern.RegisterRewritePattern(dropout_inplace_pattern, no_dropout_inplace_pattern);
33-
remove_dropout_inplace_pattern.runOnGraph(graph);
34-
35-
// remove feature_dropout
36-
std::string feature_dropout_pattern = R"IR(
37-
graph(%input, %4, %5):
38-
%6 = aten::feature_dropout(%input, %4, %5)
39-
return (%6))IR";
40-
std::string no_feature_dropout_pattern = R"IR(
41-
graph(%input, %4, %5):
42-
return (%input))IR";
43-
44-
torch::jit::SubgraphRewriter remove_feature_dropout_pattern;
45-
remove_feature_dropout_pattern.RegisterRewritePattern(feature_dropout_pattern, no_feature_dropout_pattern);
46-
remove_feature_dropout_pattern.runOnGraph(graph);
47-
48-
// remove feature_dropout inplace
49-
std::string feature_dropout_inplace_pattern = R"IR(
50-
graph(%input, %4, %5):
51-
%6 = aten::feature_dropout_(%input, %4, %5)
52-
return (%6))IR";
53-
std::string no_feature_dropout_inplace_pattern = R"IR(
54-
graph(%input, %4, %5):
55-
return (%input))IR";
56-
57-
torch::jit::SubgraphRewriter remove_feature_dropout_inplace_pattern;
58-
remove_feature_dropout_inplace_pattern.RegisterRewritePattern(
59-
feature_dropout_inplace_pattern, no_feature_dropout_inplace_pattern);
60-
remove_feature_dropout_inplace_pattern.runOnGraph(graph);
61-
62-
// remove feature_alpha_dropout
63-
std::string feature_alpha_dropout_pattern = R"IR(
64-
graph(%input, %4, %5):
65-
%6 = aten::feature_alpha_dropout(%input, %4, %5)
66-
return (%6))IR";
67-
std::string no_feature_alpha_dropout_pattern = R"IR(
68-
graph(%input, %4, %5):
69-
return (%input))IR";
70-
71-
torch::jit::SubgraphRewriter remove_feature_alpha_dropout_pattern;
72-
remove_feature_alpha_dropout_pattern.RegisterRewritePattern(
73-
feature_alpha_dropout_pattern, no_feature_alpha_dropout_pattern);
74-
remove_feature_alpha_dropout_pattern.runOnGraph(graph);
75-
76-
// remove feature_alpha_dropout inplace
77-
std::string feature_alpha_dropout_inplace_pattern = R"IR(
78-
graph(%input, %4, %5):
79-
%6 = aten::feature_alpha_dropout_(%input, %4, %5)
80-
return (%6))IR";
81-
std::string no_feature_alpha_dropout_inplace_pattern = R"IR(
82-
graph(%input, %4, %5):
83-
return (%input))IR";
84-
85-
torch::jit::SubgraphRewriter remove_feature_alpha_dropout_inplace_pattern;
86-
remove_feature_alpha_dropout_inplace_pattern.RegisterRewritePattern(
87-
feature_alpha_dropout_inplace_pattern, no_feature_alpha_dropout_inplace_pattern);
88-
remove_feature_alpha_dropout_inplace_pattern.runOnGraph(graph);
10+
// Schemas for dropout variants
11+
const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
12+
c10::Symbol::fromQualString("aten::dropout"),
13+
c10::Symbol::fromQualString("aten::dropout_"),
14+
c10::Symbol::fromQualString("aten::feature_dropout"),
15+
c10::Symbol::fromQualString("aten::feature_dropout_"),
16+
c10::Symbol::fromQualString("aten::feature_alpha_dropout"),
17+
c10::Symbol::fromQualString("aten::feature_alpha_dropout_"),
18+
};
19+
20+
void removeDropoutInBlock(torch::jit::Block* block) {
21+
/*
22+
Function adapted from:
23+
torch/csrc/jit/passes/remove_dropout.cpp
24+
25+
Modified for conciseness, documentation, and allowing new variants of dropout operators to be quickly added
26+
*/
27+
std::vector<torch::jit::Node*> dropout_nodes_to_remove;
28+
29+
for (auto node : block->nodes()) {
30+
// Remove dropout for each member block within a node
31+
for (auto block : node->blocks()) {
32+
removeDropoutInBlock(block);
33+
}
34+
35+
// For each node having a dropout-variant Schema, remove the node
36+
if (DropoutNodeKinds.find(node->kind()) != DropoutNodeKinds.end()) {
37+
// Extract input and output tensors of dropout operator
38+
auto input_value = node->inputs()[0];
39+
auto output_value = node->outputs()[0];
40+
41+
output_value->replaceAllUsesWith(input_value);
42+
dropout_nodes_to_remove.push_back(node);
43+
}
44+
}
45+
46+
// Delete dropout nodes
47+
for (auto del_node : dropout_nodes_to_remove) {
48+
del_node->destroy();
49+
}
50+
}
8951

52+
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
53+
// Remove all instances of dropout variants from graph
54+
removeDropoutInBlock(graph->block());
55+
torch::jit::EliminateDeadCode(graph);
9056
LOG_GRAPH("Post remove dropout: " << *graph);
9157
}
9258

tests/core/lowering/test_remove_dropout_pass.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,32 @@ TEST(LoweringPasses, RemoveDropoutLowersCorrectly) {
3232
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
3333
}
3434

35+
TEST(LoweringPasses, RemoveDropoutNestedLowersCorrectly) {
36+
std::string source_graph = R"IR(
37+
graph(%x.1):
38+
%3 : float = prim::Constant[value=0.5]()
39+
%4 : bool = prim::Constant[value=0]()
40+
%y.1 : Tensor = aten::dropout(%x.1, %3, %4)
41+
%z.1 : Tensor = aten::dropout(%y.1, %3, %4)
42+
%12 : Tensor = aten::relu(%z.1)
43+
return (%12))IR";
44+
std::string target_graph = R"IR(
45+
graph(%x.1):
46+
%11 : Tensor = aten::relu(%x.1)
47+
return (%11))IR";
48+
49+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
50+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
51+
auto sg = std::make_shared<torch::jit::Graph>();
52+
torch::jit::parseIR(source_graph, sg.get());
53+
torch_tensorrt::core::lowering::passes::RemoveDropout(sg);
54+
55+
auto tg = std::make_shared<torch::jit::Graph>();
56+
torch::jit::parseIR(target_graph, tg.get());
57+
58+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
59+
}
60+
3561
TEST(LoweringPasses, RemoveDropoutInplaceLowersCorrectly) {
3662
std::string source_graph = R"IR(
3763
graph(%x.1):
@@ -132,6 +158,32 @@ TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
132158
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
133159
}
134160

161+
TEST(LoweringPasses, RemoveFeatureAlphaDropoutNestedLowersCorrectly) {
162+
std::string source_graph = R"IR(
163+
graph(%x.1):
164+
%3 : float = prim::Constant[value=0.5]()
165+
%4 : bool = prim::Constant[value=0]()
166+
%y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
167+
%z.1 : Tensor = aten::feature_alpha_dropout(%y.1, %3, %4)
168+
%12 : Tensor = aten::relu(%z.1)
169+
return (%12))IR";
170+
std::string target_graph = R"IR(
171+
graph(%x.1):
172+
%11 : Tensor = aten::relu(%x.1)
173+
return (%11))IR";
174+
175+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
176+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
177+
auto sg = std::make_shared<torch::jit::Graph>();
178+
torch::jit::parseIR(source_graph, sg.get());
179+
torch_tensorrt::core::lowering::passes::RemoveDropout(sg);
180+
181+
auto tg = std::make_shared<torch::jit::Graph>();
182+
torch::jit::parseIR(target_graph, tg.get());
183+
184+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
185+
}
186+
135187
TEST(LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) {
136188
std::string source_graph = R"IR(
137189
graph(%x.1):

0 commit comments

Comments
 (0)