Commit 8da2ea6

move trace, prepare, convert into unified_compiler_utils, split prepare_and_convert into separate functions
Differential Revision: D78751687
Pull Request resolved: #12760
1 parent: 6df857f

4 files changed: +125 -46 lines changed
backends/cadence/aot/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -41,6 +41,7 @@ python_library(
         ":ops_registrations",
         ":passes",
         ":replace_ops",
+        ":compiler_funcs",
         ":utils",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
@@ -332,6 +333,18 @@ python_library(
     ],
 )

+python_library(
+    name = "compiler_funcs",
+    srcs = [
+        "compiler_funcs.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//pytorch/ao:torchao",
+    ],
+)
+

 python_unittest(
     name = "test_graph_builder",

backends/cadence/aot/compiler.py

Lines changed: 39 additions & 44 deletions
@@ -12,6 +12,11 @@

 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
+from executorch.backends.cadence.aot.compiler_funcs import (
+    convert as convert_fn,
+    prepare as prepare_fn,
+    trace as trace_fn,
+)
 from executorch.backends.cadence.aot.memory_planning import (
     CadenceMemoryPlanning,
     print_memory_planning_info,
@@ -35,16 +40,13 @@
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge
-from torch._inductor.decomposition import remove_decompositions

 from torch.export.exported_program import ExportedProgram
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

 from .passes import apply_exir_ops_passes, apply_torch_ops_passes

 from .utils import print_ops_info

-
 default_quantizer = CadenceDefaultQuantizer()


@@ -62,13 +64,6 @@ def trace(
     Trace the model with export and return an ExportedProgram.
     """

-    # Make the model inference mode by calling model.eval()
-    model.eval()
-
-    # Get default decompositions
-    decomp_table = torch.export.default_decompositions()
-
-    # Select ops to keep
     ops_to_keep = [
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
@@ -78,63 +73,54 @@
         torch.ops.aten.rms_norm.default,
     ]

-    # Remove decompositions for the ops we want to keep
-    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
-    remove_decompositions(decomp_table, ops_to_keep)
-
-    # Export with dynamo
-    program = torch.export.export(model, inputs, strict=True).run_decompositions(
-        decomp_table
+    program = trace_fn(
+        model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
     )

     if dump_graphs:
         logging.info("Graph before quantization:")
-        logging.info(program.module().graph.print_tabular())
+        logging.info(program.graph_module.graph.print_tabular())

     return program


-def prepare_and_convert_pt2(
+def prepare_pt2(
     program: ExportedProgram,
-    inputs: tuple[object, ...],
     quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
     """
-    Prepare and convert a model using the given quantizer.
+    Prepare a model using the given quantizer.
     The quantizer must be supplied and be the same as the one used to
     fuse the model later, if applicable. If you do not expect that behavior,
     please use quantize_and_fuse_pt2 instead, which will instantiate a
     default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Returns a GraphModule with the prepared model.
     """

-    # Get the graph module from the ExportedProgram
-    model_gm = program.module()
+    prepared_model = prepare_fn(program, quantizer, is_qat=False)

-    assert isinstance(model_gm, torch.fx.GraphModule)
+    if dump_graphs:
+        logging.info("Graph after preparation:")
+        logging.info(prepared_model.graph.print_tabular())

-    # Prepare
-    prepared_model = prepare_pt2e(model_gm, quantizer)
+    return prepared_model

-    # Calibrate
-    # If no calibration data is provided, use the inputs
-    if calibration_data is None:
-        calibration_data = [inputs]

-    for samples in calibration_data:
-        prepared_model(*samples)
+def convert_pt2(
+    graph_module: torch.fx.GraphModule,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Convert the model
+    Returns a GraphModule with the converted model.
+    """

-    # Convert
-    converted_model = convert_pt2e(prepared_model)
+    converted_model = convert_fn(graph_module)

     if dump_graphs:
-        logging.info("Graph after quantization (before fusion):")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info("Graph after convert:")
+        logging.info(converted_model.graph.print_tabular())

     return converted_model

@@ -192,10 +178,19 @@ def quantize_pt2(
         logging.info("Graph after trace:")
         logging.info(program.graph.print_tabular())

+    # Get prepared graph module
+    prepared_gm = prepare_pt2(program, quantizer, dump_graphs=dump_graphs)
+
+    # Calibrate
+    # If no calibration data is provided, use the inputs
+    if calibration_data is None:
+        calibration_data = [inputs]
+
+    for samples in calibration_data:
+        prepared_gm(*samples)
+
     # Get converted graph module
-    converted_gm = prepare_and_convert_pt2(
-        program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
-    )
+    converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)

     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
backends/cadence/aot/compiler_funcs.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+from typing import Optional
+
+import torch
+from torch._inductor.decomposition import remove_decompositions
+from torchao.quantization.pt2e.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
+from torchao.quantization.pt2e.quantizer import Quantizer
+
+
+@torch.no_grad()
+def trace(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    is_qat: bool = False,
+    strict: bool = False,
+    ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
+) -> torch.export.ExportedProgram:
+    if is_qat:
+        model.train()
+    else:
+        model.eval()
+
+    decomp_table = torch.export.default_decompositions()
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
+    program = torch.export.export_for_training(
+        model, inputs, strict=strict
+    ).run_decompositions(decomp_table)
+
+    return program
+
+
+def prepare(
+    traced_program: torch.export.ExportedProgram,
+    quantizer: Quantizer,
+    is_qat: bool = False,
+) -> torch.fx.GraphModule:
+    traced_model = traced_program.module()
+    assert isinstance(traced_model, torch.fx.GraphModule)
+
+    if is_qat:
+        prepared_model = prepare_qat_pt2e(traced_model, quantizer)
+    else:
+        prepared_model = prepare_pt2e(traced_model, quantizer)
+
+    return prepared_model
+
+
+def convert(prepared_model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    converted_model = convert_pt2e(prepared_model)
+    return converted_model
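
The Cadence flow above only exercises the post-training (is_qat=False) path of these helpers. Below is a minimal sketch of calling trace, prepare, and convert directly, including the QAT branch; the model, inputs, and quantizer are illustrative placeholders, and the quantizer import path is an assumption rather than part of this commit.

import torch

from executorch.backends.cadence.aot.compiler_funcs import convert, prepare, trace

# Assumed import path for the default Cadence quantizer; prepare() accepts any
# torchao pt2e Quantizer, per its type hint above.
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

# Hypothetical model and inputs, purely for illustration.
model = torch.nn.Sequential(torch.nn.Conv1d(4, 4, 3), torch.nn.ReLU())
inputs = (torch.randn(1, 4, 16),)
quantizer = CadenceDefaultQuantizer()
# Keep conv1d from being decomposed during export, mirroring compiler.trace.
ops_to_keep = [torch.ops.aten.conv1d.default]

# PTQ path: export in eval mode, insert observers, calibrate, convert.
ep = trace(model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep)
prepared = prepare(ep, quantizer, is_qat=False)
prepared(*inputs)  # calibration
converted = convert(prepared)

# QAT path: export in train mode and insert fake-quant via prepare_qat_pt2e;
# the prepared module would be trained before calling convert().
qat_ep = trace(model, inputs, is_qat=True, strict=True, ops_to_keep=ops_to_keep)
qat_prepared = prepare(qat_ep, quantizer, is_qat=True)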

backends/cadence/aot/export_example.py

Lines changed: 10 additions & 2 deletions
@@ -15,9 +15,10 @@
 from typing import Any, Tuple

 from executorch.backends.cadence.aot.compiler import (
+    convert_pt2,
     export_to_executorch_gen_etrecord,
     fuse_pt2,
-    prepare_and_convert_pt2,
+    prepare_pt2,
     trace,
 )

@@ -52,8 +53,15 @@ def export_model(
     # Trace the model
     ep = trace(model, example_inputs)

+    # Prepare the model
+    prepared_gm = prepare_pt2(ep, quantizer)
+
+    # Calibrate the model
+    for samples in [example_inputs]:
+        prepared_gm(*samples)
+
     # Convert the model
-    converted_model = prepare_and_convert_pt2(ep, example_inputs, quantizer)
+    converted_model = convert_pt2(prepared_gm)

     # Get reference outputs from converted model
     ref_outputs = converted_model(*example_inputs)
