
Commit 0851cc4

Authored by: eellison
Update freezing API - changes from 52337 (pytorch#52392)
Co-authored-by: eellison <[email protected]>
Parent: 804f7b6

File tree

    test/cpp/jit/test_module_api.cpp
    test/jit/test_freezing.py
    torch/_C/__init__.pyi.in
    torch/csrc/jit/api/module.cpp
    torch/csrc/jit/api/module.h
    torch/csrc/jit/passes/frozen_graph_optimizations.cpp
    torch/csrc/jit/passes/frozen_graph_optimizations.h
    torch/jit/_freeze.py

8 files changed: +79 -16 lines

test/cpp/jit/test_module_api.cpp

Lines changed: 16 additions & 0 deletions
@@ -3,9 +3,11 @@
 #include <test/cpp/jit/test_utils.h>
 
 #include <ATen/core/qualified_name.h>
+#include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/frontend/resolver.h>
 #include <torch/csrc/jit/serialization/import.h>
 #include <torch/csrc/jit/serialization/import_source.h>
+#include <torch/csrc/jit/testing/file_check.h>
 #include <torch/torch.h>
 
 namespace torch {
@@ -341,6 +343,20 @@ TEST(ModuleAPITest, Define) {
   AT_ASSERT(result.toTensor().item<float>() == 6);
 }
 
+TEST(ModuleAPITest, Freezing) {
+  Module m("m");
+  m.register_parameter("foo", torch::ones({}), false);
+  m.define(R"(
+    def forward(self, x, b : int = 4):
+      return self.foo + x + b
+  )");
+  m.eval();
+  auto frozen_mod = torch::jit::freeze(m);
+  auto forward_g = frozen_mod.get_method("forward").graph();
+  testing::FileCheck().check_not("GetAttr")->run(*forward_g);
+  ;
+}
+
 TEST(ModuleAPITest, To_CUDA) {
   Module m("test");
   {
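From Python, the rough equivalent of this new C++ test looks as follows (a sketch, not part of the commit): freezing inlines the registered parameter as a constant, so no GetAttr nodes remain in the forward graph.

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = torch.nn.Parameter(torch.ones(()), requires_grad=False)

    def forward(self, x, b: int = 4):
        return self.foo + x + b

m = torch.jit.script(M()).eval()
frozen = torch.jit.freeze(m)
# `foo` has been inlined as a constant, so the graph contains no prim::GetAttr nodes.
assert "GetAttr" not in str(frozen.graph)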

test/jit/test_freezing.py

Lines changed: 1 addition & 1 deletion
@@ -1508,7 +1508,7 @@ def test_optimize_freeze_module(self):
         bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
         mod = torch.nn.Sequential(conv, bn)
         # set optimize to False here, by default freezing runs optimize_frozen_module
-        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
         # inspect frozen mod
         FileCheck().check("batch_norm").run(frozen_mod.graph)
         torch.jit.optimize_frozen_module(frozen_mod)
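In user code the change is only the keyword name. A minimal sketch of both paths (the module and its channel sizes are illustrative, not taken from the test):

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3)
bn = torch.nn.BatchNorm2d(8, eps=.001)
mod = torch.jit.script(torch.nn.Sequential(conv, bn).eval())

# Default: freezing also runs the numerics-affecting folds, fusing conv and batchnorm.
fused = torch.jit.freeze(mod)

# Opt out: batch_norm stays in the graph and can be folded later on demand.
unfused = torch.jit.freeze(mod, optimize_numerics=False)
assert "batch_norm" in str(unfused.graph)
torch.jit.optimize_frozen_module(unfused)
assert "batch_norm" not in str(unfused.graph)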

torch/_C/__init__.pyi.in

Lines changed: 5 additions & 0 deletions
@@ -174,6 +174,11 @@ def _freeze_module(module: ScriptModule,
                    freeze_interfaces: _bool = True,
                    preserveParameters: _bool = True) -> ScriptModule: ...
 def _jit_pass_optimize_frozen_graph(Graph) -> None: ...
+def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
+def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
+def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
+def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...
+
 def _is_tracing() -> _bool: ...
 def _jit_init() -> _bool: ...
 def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
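These stubs expose the passes that torch.jit.optimize_frozen_module now calls directly (see torch/jit/_freeze.py below). Note the asymmetry: dropout removal takes the C++ module, while the conv folds take a graph. A sketch of invoking them by hand, assuming frozen is an already-frozen ScriptModule:

import torch

# `frozen` is assumed to come from torch.jit.freeze(...) on an eval-mode module.
torch._C._jit_pass_remove_dropout(frozen._c)                  # operates on the module
torch._C._jit_pass_fold_frozen_conv_bn(frozen.graph)          # the folds operate on the graph
torch._C._jit_pass_fold_frozen_conv_add_or_sub(frozen.graph)
torch._C._jit_pass_fold_frozen_conv_mul_or_div(frozen.graph)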

torch/csrc/jit/api/module.cpp

Lines changed: 17 additions & 0 deletions
@@ -8,6 +8,8 @@
 #include <torch/csrc/jit/frontend/schema_matching.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
 #include <torch/csrc/jit/passes/inliner.h>
 #include <torch/csrc/jit/runtime/operator.h>
 
@@ -336,6 +338,21 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
   return obj;
 }
 
+Module freeze(
+    const Module& module,
+    c10::optional<std::vector<std::string>> preserved_attrs,
+    bool optimize_numerics) {
+  TORCH_CHECK(
+      !module.is_training(),
+      "Freezing is currently only implemented for modules in eval mode. Please call .eval() before freezing");
+
+  Module out_mod = freeze_module(
+      module, preserved_attrs.value_or(std::vector<std::string>({})));
+  auto graph = out_mod.get_method("forward").graph();
+  OptimizeFrozenGraph(graph, optimize_numerics);
+  return out_mod;
+}
+
 buffer_list Module::buffers(bool recurse) const {
   return buffer_list(*this, recurse, /*return_module=*/false);
 }
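The TORCH_CHECK mirrors the precondition already enforced by the Python API: freeze only accepts modules that are in eval mode. A sketch of that behavior from the Python side (the error text shown is approximate):

import torch

scripted = torch.jit.script(torch.nn.Linear(2, 2))  # nn modules default to training mode

try:
    torch.jit.freeze(scripted)          # rejected: module is not in eval mode
except RuntimeError as err:
    print(err)                          # "... Please call .eval() before freezing"

frozen = torch.jit.freeze(scripted.eval())  # accepted after .eval()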

torch/csrc/jit/api/module.h

Lines changed: 7 additions & 0 deletions
@@ -276,6 +276,13 @@ struct TORCH_API Module : public Object {
       bool non_blocking);
 };
 
+// C++ equivalent api of `torch.jit.freeze`. See documentation there for
+// details.
+TORCH_API Module freeze(
+    const Module& module,
+    c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
+    bool optimize_numerics = true);
+
 namespace detail {
 
 struct TORCH_API SlotCursor {

torch/csrc/jit/passes/frozen_graph_optimizations.cpp

Lines changed: 9 additions & 5 deletions
@@ -8,12 +8,16 @@
 namespace torch {
 namespace jit {
 
-void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph) {
+void OptimizeFrozenGraph(
+    std::shared_ptr<Graph>& graph,
+    bool optimize_numerics) {
   // run a couple times to capture Conv -> Mul -> Add etc
-  for (size_t i = 0; i < 2; i++) {
-    FoldFrozenConvBatchnorm(graph);
-    FoldFrozenConvAddOrSub(graph);
-    FoldFrozenConvMulOrDiv(graph);
+  if (optimize_numerics) {
+    for (size_t i = 0; i < 2; i++) {
+      FoldFrozenConvBatchnorm(graph);
+      FoldFrozenConvAddOrSub(graph);
+      FoldFrozenConvMulOrDiv(graph);
+    }
   }
 }
 
torch/csrc/jit/passes/frozen_graph_optimizations.h

Lines changed: 3 additions & 1 deletion
@@ -13,7 +13,9 @@
 namespace torch {
 namespace jit {
 
-TORCH_API void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph);
+TORCH_API void OptimizeFrozenGraph(
+    std::shared_ptr<Graph>& graph,
+    bool optimize_numerics = true);
 
 } // namespace jit
 } // namespace torch

torch/jit/_freeze.py

Lines changed: 21 additions & 9 deletions
@@ -10,7 +10,7 @@
 from torch.jit._script import RecursiveScriptModule, ScriptModule
 
 
-def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True):
+def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True):
     r"""
     Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
     module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
@@ -26,10 +26,8 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = Tr
         preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
         Attributes modified in preserved methods will also be preserved.
 
-        optimize (bool): If ``True``, a set of optimization passes will be run to prepare the graph for inference,
-        in addition to the graph cleanup that already occurs. The details of the optimizations can be found in
-        `torch.jit.optimize_frozen_module.`
-
+        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that do not strictly
+        preserve numerics. Full details of the optimizations can be found in `torch.jit.optimize_frozen_module`.
 
     Returns:
         Frozen :class:`ScriptModule`.
@@ -102,23 +100,29 @@ def forward(self, input):
 
     out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
     RecursiveScriptModule._finalize_scriptmodule(out)
-    if optimize:
-        optimize_frozen_module(out)
+    optimize_frozen_module(out, optimize_numerics)
 
     return out
 
 
-def optimize_frozen_module(mod):
+def optimize_frozen_module(mod, optimize_numerics: bool = True):
     r"""
     Runs a series of optimizations looking for patterns that occur in frozen graphs.
     The current set of optimizations is:
+        - Dropout Removal
         - Conv -> Batchnorm folding
         - Conv -> Add/Sub folding
         - Conv -> Mul/Div folding
 
     Args:
         mod (:class:`ScriptModule`): a frozen module to be optimized
 
+        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that do not strictly
+        preserve numerics. These optimizations preserve the default rtol and atol of `torch.testing.assert_allclose`
+        when applied to a single transformation; however, in a module where many transformations are applied,
+        the rtol or atol may no longer fall within the default `assert_allclose` tolerance. Conv -> Batchnorm,
+        Conv -> Add/Sub, and Conv -> Mul/Div folding may all alter numerics.
+
     Returns:
         None
 
@@ -140,4 +144,12 @@ def optimize_frozen_module(mod):
         assert "batch_norm" not in str(frozen_mod.graph)
 
     """
-    torch._C._jit_pass_optimize_frozen_graph(mod.graph)
+    # xxx: keep in sync with frozen_graph_optimizations.cpp
+    # intentionally duplicated to make it easier to create custom optimization sequences
+    torch._C._jit_pass_remove_dropout(mod._c)
+    if optimize_numerics:
+        # run a couple times to capture Conv -> Mul -> Add etc
+        for _ in range(2):
+            torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
+            torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
+            torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)
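Because the pass sequence is now spelled out in Python rather than hidden behind the single _jit_pass_optimize_frozen_graph binding, callers can assemble their own sequence. A sketch under the assumption of a hypothetical user model where only Conv -> Batchnorm folding is wanted:

import torch

# `model` is a hypothetical user-defined nn.Module.
frozen = torch.jit.freeze(torch.jit.script(model.eval()), optimize_numerics=False)
# Dropout removal already ran inside freeze(); apply only the fold that is acceptable here.
torch._C._jit_pass_fold_frozen_conv_bn(frozen.graph)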
