@@ -12,6 +12,11 @@
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
+from executorch.backends.cadence.aot.compiler_funcs import (
+    convert as convert_fn,
+    prepare as prepare_fn,
+    trace as trace_fn,
+)
 from executorch.backends.cadence.aot.memory_planning import (
     CadenceMemoryPlanning,
     print_memory_planning_info,
@@ -35,16 +40,13 @@
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge
-from torch._inductor.decomposition import remove_decompositions
 
 from torch.export.exported_program import ExportedProgram
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 from .passes import apply_exir_ops_passes, apply_torch_ops_passes
 
 from .utils import print_ops_info
 
-
 default_quantizer = CadenceDefaultQuantizer()
 
 
@@ -62,13 +64,6 @@ def trace(
     Trace the model with export and return an ExportedProgram.
     """
 
-    # Make the model inference mode by calling model.eval()
-    model.eval()
-
-    # Get default decompositions
-    decomp_table = torch.export.default_decompositions()
-
-    # Select ops to keep
     ops_to_keep = [
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
@@ -78,63 +73,54 @@ def trace(
         torch.ops.aten.rms_norm.default,
     ]
 
-    # Remove decompositions for the ops we want to keep
-    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
-    remove_decompositions(decomp_table, ops_to_keep)
-
-    # Export with dynamo
-    program = torch.export.export(model, inputs, strict=True).run_decompositions(
-        decomp_table
+    program = trace_fn(
+        model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
     )
 
     if dump_graphs:
         logging.info("Graph before quantization:")
-        logging.info(program.module().graph.print_tabular())
+        logging.info(program.graph_module.graph.print_tabular())
 
     return program
 
 
-def prepare_and_convert_pt2(
+def prepare_pt2(
     program: ExportedProgram,
-    inputs: tuple[object, ...],
     quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
     """
-    Prepare and convert a model using the given quantizer.
+    Prepare a model using the given quantizer.
     The quantizer must be supplied and be the same as the one used to
     fuse the model later, if applicable. If you do not expect that behavior,
     please use quantize_and_fuse_pt2 instead, which will instantiate a
     default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Returns a GraphModule with the prepared model.
     """
 
-    # Get the graph module from the ExportedProgram
-    model_gm = program.module()
+    prepared_model = prepare_fn(program, quantizer, is_qat=False)
 
-    assert isinstance(model_gm, torch.fx.GraphModule)
+    if dump_graphs:
+        logging.info("Graph after preparation:")
+        logging.info(prepared_model.graph.print_tabular())
 
-    # Prepare
-    prepared_model = prepare_pt2e(model_gm, quantizer)
+    return prepared_model
 
-    # Calibrate
-    # If no calibration data is provided, use the inputs
-    if calibration_data is None:
-        calibration_data = [inputs]
 
-    for samples in calibration_data:
-        prepared_model(*samples)
+def convert_pt2(
+    graph_module: torch.fx.GraphModule,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Convert the model.
+    Returns a GraphModule with the converted model.
+    """
 
-    # Convert
-    converted_model = convert_pt2e(prepared_model)
+    converted_model = convert_fn(graph_module)
 
     if dump_graphs:
-        logging.info("Graph after quantization (before fusion):")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info("Graph after convert:")
+        logging.info(converted_model.graph.print_tabular())
 
     return converted_model
 
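The new compiler_funcs module is not part of this diff, so as a reading aid, here is a minimal sketch reconstructed from the inline logic deleted above. The helper names and the is_qat, strict, and ops_to_keep parameters are taken from the call sites; the QAT branch (prepare_qat_pt2e) is an assumption, not something this diff shows.

# Sketch only: assumed contents of compiler_funcs, mirroring the deleted
# inline code. Do not treat as the actual module.
from typing import Any, Optional, Sequence

import torch
from torch._inductor.decomposition import remove_decompositions
from torch.export.exported_program import ExportedProgram
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)


def trace(
    model: torch.nn.Module,
    inputs: tuple[Any, ...],
    is_qat: bool = False,
    strict: bool = True,
    ops_to_keep: Optional[Sequence[Any]] = None,
) -> ExportedProgram:
    # Inference mode for PTQ; assume QAT tracing skips eval().
    if not is_qat:
        model.eval()
    # Drop the selected ops from the default decomposition table so they
    # survive export intact (this mirrors the deleted inline logic).
    decomp_table = torch.export.default_decompositions()
    remove_decompositions(decomp_table, list(ops_to_keep or []))
    return torch.export.export(model, inputs, strict=strict).run_decompositions(
        decomp_table
    )


def prepare(
    program: ExportedProgram, quantizer: Any, is_qat: bool = False
) -> torch.fx.GraphModule:
    # Insert observers (fake-quant modules for QAT) into the exported graph.
    gm = program.module()
    assert isinstance(gm, torch.fx.GraphModule)
    return prepare_qat_pt2e(gm, quantizer) if is_qat else prepare_pt2e(gm, quantizer)


def convert(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Replace observers with quantize/dequantize ops.
    return convert_pt2e(graph_module)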
@@ -192,10 +178,19 @@ def quantize_pt2(
         logging.info("Graph after trace:")
         logging.info(program.graph.print_tabular())
 
+    # Get prepared graph module
+    prepared_gm = prepare_pt2(program, quantizer, dump_graphs=dump_graphs)
+
+    # Calibrate
+    # If no calibration data is provided, use the inputs
+    if calibration_data is None:
+        calibration_data = [inputs]
+
+    for samples in calibration_data:
+        prepared_gm(*samples)
+
     # Get converted graph module
-    converted_gm = prepare_and_convert_pt2(
-        program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
-    )
+    converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
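Net effect of the three hunks: the single-shot prepare_and_convert_pt2 is split so that calibration happens at the call site, between prepare and convert. A hedged end-to-end sketch of the resulting flow follows; model, example_inputs, and calib_batches are placeholders, while trace, prepare_pt2, convert_pt2, fuse_pt2, and CadenceDefaultQuantizer are the names visible in this file.

# Hypothetical caller; only the function names come from this diff.
quantizer = CadenceDefaultQuantizer()

program = trace(model, example_inputs)           # ExportedProgram
prepared_gm = prepare_pt2(program, quantizer)    # observers inserted

# Calibration now lives at the call site, between prepare and convert.
for samples in calib_batches:
    prepared_gm(*samples)

converted_gm = convert_pt2(prepared_gm)          # observers -> q/dq ops
fused_gm = fuse_pt2(converted_gm, quantizer)     # Cadence-specific fusion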