1
- #include < torch/csrc/jit/passes/subgraph_rewrite.h >
1
+ #include " torch/csrc/jit/passes/dead_code_elimination.h "
2
2
3
3
#include " core/util/prelude.h"
4
4
@@ -7,86 +7,52 @@ namespace core {
7
7
namespace lowering {
8
8
namespace passes {
9
9
10
- void RemoveDropout (std::shared_ptr<torch::jit::Graph>& graph) {
11
- std::string dropout_pattern = R"IR(
12
- graph(%input, %4, %5):
13
- %6 = aten::dropout(%input, %4, %5)
14
- return (%6))IR" ;
15
- std::string no_dropout_pattern = R"IR(
16
- graph(%input, %4, %5):
17
- return (%input))IR" ;
18
-
19
- torch::jit::SubgraphRewriter remove_dropout;
20
- remove_dropout.RegisterRewritePattern (dropout_pattern, no_dropout_pattern);
21
- remove_dropout.runOnGraph (graph);
22
-
23
- std::string dropout_inplace_pattern = R"IR(
24
- graph(%input, %4, %5):
25
- %6 = aten::dropout_(%input, %4, %5)
26
- return (%6))IR" ;
27
- std::string no_dropout_inplace_pattern = R"IR(
28
- graph(%input, %4, %5):
29
- return (%input))IR" ;
30
-
31
- torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
32
- remove_dropout_inplace_pattern.RegisterRewritePattern (dropout_inplace_pattern, no_dropout_inplace_pattern);
33
- remove_dropout_inplace_pattern.runOnGraph (graph);
34
-
35
- // remove feature_dropout
36
- std::string feature_dropout_pattern = R"IR(
37
- graph(%input, %4, %5):
38
- %6 = aten::feature_dropout(%input, %4, %5)
39
- return (%6))IR" ;
40
- std::string no_feature_dropout_pattern = R"IR(
41
- graph(%input, %4, %5):
42
- return (%input))IR" ;
43
-
44
- torch::jit::SubgraphRewriter remove_feature_dropout_pattern;
45
- remove_feature_dropout_pattern.RegisterRewritePattern (feature_dropout_pattern, no_feature_dropout_pattern);
46
- remove_feature_dropout_pattern.runOnGraph (graph);
47
-
48
- // remove feature_dropout inplace
49
- std::string feature_dropout_inplace_pattern = R"IR(
50
- graph(%input, %4, %5):
51
- %6 = aten::feature_dropout_(%input, %4, %5)
52
- return (%6))IR" ;
53
- std::string no_feature_dropout_inplace_pattern = R"IR(
54
- graph(%input, %4, %5):
55
- return (%input))IR" ;
56
-
57
- torch::jit::SubgraphRewriter remove_feature_dropout_inplace_pattern;
58
- remove_feature_dropout_inplace_pattern.RegisterRewritePattern (
59
- feature_dropout_inplace_pattern, no_feature_dropout_inplace_pattern);
60
- remove_feature_dropout_inplace_pattern.runOnGraph (graph);
61
-
62
- // remove feature_alpha_dropout
63
- std::string feature_alpha_dropout_pattern = R"IR(
64
- graph(%input, %4, %5):
65
- %6 = aten::feature_alpha_dropout(%input, %4, %5)
66
- return (%6))IR" ;
67
- std::string no_feature_alpha_dropout_pattern = R"IR(
68
- graph(%input, %4, %5):
69
- return (%input))IR" ;
70
-
71
- torch::jit::SubgraphRewriter remove_feature_alpha_dropout_pattern;
72
- remove_feature_alpha_dropout_pattern.RegisterRewritePattern (
73
- feature_alpha_dropout_pattern, no_feature_alpha_dropout_pattern);
74
- remove_feature_alpha_dropout_pattern.runOnGraph (graph);
75
-
76
- // remove feature_alpha_dropout inplace
77
- std::string feature_alpha_dropout_inplace_pattern = R"IR(
78
- graph(%input, %4, %5):
79
- %6 = aten::feature_alpha_dropout_(%input, %4, %5)
80
- return (%6))IR" ;
81
- std::string no_feature_alpha_dropout_inplace_pattern = R"IR(
82
- graph(%input, %4, %5):
83
- return (%input))IR" ;
84
-
85
- torch::jit::SubgraphRewriter remove_feature_alpha_dropout_inplace_pattern;
86
- remove_feature_alpha_dropout_inplace_pattern.RegisterRewritePattern (
87
- feature_alpha_dropout_inplace_pattern, no_feature_alpha_dropout_inplace_pattern);
88
- remove_feature_alpha_dropout_inplace_pattern.runOnGraph (graph);
10
+ // Schemas for dropout variants
11
+ const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
12
+ c10::Symbol::fromQualString (" aten::dropout" ),
13
+ c10::Symbol::fromQualString (" aten::dropout_" ),
14
+ c10::Symbol::fromQualString (" aten::feature_dropout" ),
15
+ c10::Symbol::fromQualString (" aten::feature_dropout_" ),
16
+ c10::Symbol::fromQualString (" aten::feature_alpha_dropout" ),
17
+ c10::Symbol::fromQualString (" aten::feature_alpha_dropout_" ),
18
+ };
19
+
20
+ void removeDropoutInBlock (torch::jit::Block* block) {
21
+ /*
22
+ Function adapted from:
23
+ torch/csrc/jit/passes/remove_dropout.cpp
24
+
25
+ Modified for conciseness, documentation, and allowing new variants of dropout operators to be quickly added
26
+ */
27
+ std::vector<torch::jit::Node*> dropout_nodes_to_remove;
28
+
29
+ for (auto node : block->nodes ()) {
30
+ // Remove dropout for each member block within a node
31
+ for (auto block : node->blocks ()) {
32
+ removeDropoutInBlock (block);
33
+ }
34
+
35
+ // For each node having a dropout-variant Schema, remove the node
36
+ if (DropoutNodeKinds.find (node->kind ()) != DropoutNodeKinds.end ()) {
37
+ // Extract input and output tensors of dropout operator
38
+ auto input_value = node->inputs ()[0 ];
39
+ auto output_value = node->outputs ()[0 ];
40
+
41
+ output_value->replaceAllUsesWith (input_value);
42
+ dropout_nodes_to_remove.push_back (node);
43
+ }
44
+ }
45
+
46
+ // Delete dropout nodes
47
+ for (auto del_node : dropout_nodes_to_remove) {
48
+ del_node->destroy ();
49
+ }
50
+ }
89
51
52
+ void RemoveDropout (std::shared_ptr<torch::jit::Graph>& graph) {
53
+ // Remove all instances of dropout variants from graph
54
+ removeDropoutInBlock (graph->block ());
55
+ torch::jit::EliminateDeadCode (graph);
90
56
LOG_GRAPH (" Post remove dropout: " << *graph);
91
57
}
92
58
0 commit comments