15
15
from torch .testing ._internal .common_utils import IS_WINDOWS , TestCase , run_tests
16
16
17
17
from torchao .quantization .pt2e import (
18
- CUSTOM_KEY ,
19
- NUMERIC_DEBUG_HANDLE_KEY ,
18
+ FROM_NODE_KEY ,
20
19
compare_results ,
21
20
extract_results_from_loggers ,
22
- generate_numeric_debug_handle ,
23
21
prepare_for_propagation_comparison ,
24
22
)
23
+ from torchao .quantization .pt2e ._numeric_debugger import _generate_debug_handle_from_node
25
24
from torchao .quantization .pt2e .graph_utils import bfs_trace_with_node_process
26
25
from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
27
26
from torchao .testing .pt2e ._xnnpack_quantizer import (
39
38
class TestNumericDebugger (TestCase ):
40
39
def _assert_each_node_has_debug_handle (self , model ) -> None :
41
40
def _assert_node_has_debug_handle (node ):
42
- self .assertTrue (
43
- CUSTOM_KEY in node . meta
44
- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [ CUSTOM_KEY ] ,
45
- f"Node { node } doesn't have debug handle " ,
41
+ self .assertIn (
42
+ FROM_NODE_KEY ,
43
+ node .meta ,
44
+ f"Node { node } doesn't have from_node info " ,
46
45
)
47
46
48
47
bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
@@ -52,13 +51,8 @@ def _extract_debug_handles(self, model) -> dict[str, int]:
52
51
53
52
def _extract_debug_handles_from_node (node ):
54
53
nonlocal debug_handle_map
55
- if (
56
- CUSTOM_KEY in node .meta
57
- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
58
- ):
59
- debug_handle_map [str (node )] = node .meta [CUSTOM_KEY ][
60
- NUMERIC_DEBUG_HANDLE_KEY
61
- ]
54
+ if (dh := _generate_debug_handle_from_node (node )) is not None :
55
+ debug_handle_map [str (node )] = dh
62
56
63
57
bfs_trace_with_node_process (model , _extract_debug_handles_from_node )
64
58
@@ -69,12 +63,9 @@ def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]:
69
63
70
64
def _extract_debug_handles_with_prev_decomp_op_from_node (node ):
71
65
nonlocal prev_decomp_op_to_debug_handle_map
72
- if (
73
- CUSTOM_KEY in node .meta
74
- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
75
- ):
66
+ if FROM_NODE_KEY in node .meta :
76
67
prev_decomp_op = str (node .meta .get ("nn_module_stack" ))
77
- debug_handle = node . meta [ CUSTOM_KEY ][ NUMERIC_DEBUG_HANDLE_KEY ]
68
+ debug_handle = _generate_debug_handle_from_node ( node )
78
69
if prev_decomp_op not in prev_decomp_op_to_debug_handle_map :
79
70
prev_decomp_op_to_debug_handle_map [prev_decomp_op ] = debug_handle
80
71
else :
@@ -96,64 +87,73 @@ def test_simple(self):
96
87
m = TestHelperModules .Conv2dThenConv1d ()
97
88
example_inputs = m .example_inputs ()
98
89
ep = export_for_training (m , example_inputs , strict = True )
99
- generate_numeric_debug_handle ( ep )
100
- self ._assert_each_node_has_debug_handle (ep )
101
- debug_handle_map = self ._extract_debug_handles (ep )
90
+ m = ep . module ( )
91
+ self ._assert_each_node_has_debug_handle (m )
92
+ debug_handle_map = self ._extract_debug_handles (m )
102
93
103
94
self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
104
95
96
+ @unittest .skip ("debug flow not working on model with conditional control flow" )
105
97
def test_control_flow (self ):
106
98
m = TestHelperModules .ControlFlow ()
107
99
example_inputs = m .example_inputs ()
108
100
ep = export_for_training (m , example_inputs , strict = True )
109
- generate_numeric_debug_handle ( ep )
101
+ m = ep . module ( )
110
102
111
- self ._assert_each_node_has_debug_handle (ep )
112
- debug_handle_map = self ._extract_debug_handles (ep )
103
+ self ._assert_each_node_has_debug_handle (m )
104
+ debug_handle_map = self ._extract_debug_handles (m )
113
105
114
106
self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
115
107
116
108
def test_quantize_pt2e_preserve_handle (self ):
117
109
m = TestHelperModules .Conv2dThenConv1d ()
118
110
example_inputs = m .example_inputs ()
119
111
ep = export_for_training (m , example_inputs , strict = True )
120
- generate_numeric_debug_handle (ep )
121
112
m = ep .module ()
122
113
123
114
quantizer = XNNPACKQuantizer ().set_global (
124
115
get_symmetric_quantization_config (is_per_channel = False )
125
116
)
126
117
m = prepare_pt2e (m , quantizer )
127
118
debug_handle_map = self ._extract_debug_handles (m )
119
+ node_name_equip_with_output_observer = [
120
+ "conv2d" ,
121
+ "conv1d" ,
122
+ "squeeze" ,
123
+ ]
128
124
res_counter = Counter (debug_handle_map .values ())
129
- repeated_debug_handle_ids = [1 , 2 , 3 ]
125
+ repeated_debug_handle_ids = [
126
+ debug_handle_map [n_name ] for n_name in node_name_equip_with_output_observer
127
+ ]
130
128
# 3 ids were repeated because we copy over the id from node to its output observer
131
129
# torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
132
130
for dh_id in repeated_debug_handle_ids :
133
131
self .assertEqual (res_counter [dh_id ], 2 )
134
132
135
133
m (* example_inputs )
136
134
m = convert_pt2e (m )
137
- self ._assert_each_node_has_debug_handle (ep )
135
+ self ._assert_each_node_has_debug_handle (m )
138
136
debug_handle_map = self ._extract_debug_handles (m )
139
137
res_counter = Counter (debug_handle_map .values ())
140
138
# same set of ids where repeated, because we copy over the id from observer/fake_quant to
141
- # dequantize node
142
- repeated_debug_handle_ids = [1 , 2 , 3 ]
139
+ # quantize/dequantize node
140
+ repeated_debug_handle_ids = [
141
+ debug_handle_map [n_name ] for n_name in node_name_equip_with_output_observer
142
+ ]
143
143
for dh_id in repeated_debug_handle_ids :
144
- self .assertEqual (res_counter [dh_id ], 2 )
144
+ self .assertEqual (res_counter [dh_id ], 3 )
145
145
146
146
def test_copy_preserve_handle (self ):
147
147
m = TestHelperModules .Conv2dThenConv1d ()
148
148
example_inputs = m .example_inputs ()
149
149
ep = torch .export .export (m , example_inputs , strict = True )
150
- generate_numeric_debug_handle ( ep )
150
+ m = ep . module ( )
151
151
152
- self ._assert_each_node_has_debug_handle (ep )
153
- debug_handle_map_ref = self ._extract_debug_handles (ep )
152
+ self ._assert_each_node_has_debug_handle (m )
153
+ debug_handle_map_ref = self ._extract_debug_handles (m )
154
154
155
155
ep_copy = copy .copy (ep )
156
- debug_handle_map = self ._extract_debug_handles (ep_copy )
156
+ debug_handle_map = self ._extract_debug_handles (ep_copy . module () )
157
157
158
158
self ._assert_each_node_has_debug_handle (ep )
159
159
self .assertEqual (debug_handle_map , debug_handle_map_ref )
@@ -162,13 +162,12 @@ def test_deepcopy_preserve_handle(self):
162
162
m = TestHelperModules .Conv2dThenConv1d ()
163
163
example_inputs = m .example_inputs ()
164
164
ep = torch .export .export (m , example_inputs , strict = True )
165
- generate_numeric_debug_handle (ep )
166
165
167
- debug_handle_map_ref = self ._extract_debug_handles (ep )
166
+ debug_handle_map_ref = self ._extract_debug_handles (ep . module () )
168
167
ep_copy = copy .deepcopy (ep )
169
- debug_handle_map = self ._extract_debug_handles (ep_copy )
168
+ debug_handle_map = self ._extract_debug_handles (ep_copy . module () )
170
169
171
- self ._assert_each_node_has_debug_handle (ep )
170
+ self ._assert_each_node_has_debug_handle (ep . module () )
172
171
self .assertEqual (debug_handle_map , debug_handle_map_ref )
173
172
174
173
@unittest .skip (
@@ -178,16 +177,16 @@ def test_re_export_preserve_handle(self):
178
177
m = TestHelperModules .Conv2dThenConv1d ()
179
178
example_inputs = m .example_inputs ()
180
179
ep = export_for_training (m , example_inputs , strict = True )
181
- generate_numeric_debug_handle (ep )
182
180
m = ep .module ()
183
181
184
- self ._assert_each_node_has_debug_handle (ep )
185
- debug_handle_map_ref = self ._extract_debug_handles (ep )
182
+ self ._assert_each_node_has_debug_handle (m )
183
+ debug_handle_map_ref = self ._extract_debug_handles (m )
186
184
187
185
ep_reexport = export_for_training (m , example_inputs , strict = True )
186
+ m_reexport = ep_reexport .module ()
188
187
189
- self ._assert_each_node_has_debug_handle (ep_reexport )
190
- debug_handle_map = self ._extract_debug_handles (ep_reexport )
188
+ self ._assert_each_node_has_debug_handle (m_reexport )
189
+ debug_handle_map = self ._extract_debug_handles (m_reexport )
191
190
192
191
self .assertEqual (debug_handle_map , debug_handle_map_ref )
193
192
@@ -198,16 +197,17 @@ def test_run_decompositions_same_handle_id(self):
198
197
m = TestHelperModules .Conv2dThenConv1d ()
199
198
example_inputs = m .example_inputs ()
200
199
ep = export_for_training (m , example_inputs , strict = True )
201
- generate_numeric_debug_handle ( ep )
200
+ m = ep . module ( )
202
201
203
- self ._assert_each_node_has_debug_handle (ep )
204
- debug_handle_map_ref = self ._extract_debug_handles (ep )
202
+ self ._assert_each_node_has_debug_handle (m )
203
+ debug_handle_map_ref = self ._extract_debug_handles (m )
205
204
206
205
ep_copy = copy .copy (ep )
207
206
ep_copy = ep_copy .run_decompositions ()
207
+ m_decomposed = ep_copy .module ()
208
208
209
- self ._assert_each_node_has_debug_handle (ep_copy )
210
- debug_handle_map = self ._extract_debug_handles (ep_copy )
209
+ self ._assert_each_node_has_debug_handle (m_decomposed )
210
+ debug_handle_map = self ._extract_debug_handles (m_decomposed )
211
211
212
212
# checking the map still has the same ids, the node may change
213
213
self .assertEqual (
@@ -226,18 +226,19 @@ def test_run_decompositions_map_handle_to_new_nodes(self):
226
226
for m in test_models :
227
227
example_inputs = m .example_inputs ()
228
228
ep = export_for_training (m , example_inputs , strict = True )
229
- generate_numeric_debug_handle ( ep )
229
+ m = ep . module ( )
230
230
231
- self ._assert_each_node_has_debug_handle (ep )
231
+ self ._assert_each_node_has_debug_handle (m )
232
232
pre_decomp_to_debug_handle_map_ref = (
233
- self ._extract_debug_handles_with_prev_decomp_op (ep )
233
+ self ._extract_debug_handles_with_prev_decomp_op (m )
234
234
)
235
235
236
236
ep_copy = copy .copy (ep )
237
237
ep_copy = ep_copy .run_decompositions ()
238
- self ._assert_each_node_has_debug_handle (ep_copy )
238
+ m_decomposed = ep_copy .module ()
239
+ self ._assert_each_node_has_debug_handle (m_decomposed )
239
240
pre_decomp_to_debug_handle_map = (
240
- self ._extract_debug_handles_with_prev_decomp_op (ep_copy )
241
+ self ._extract_debug_handles_with_prev_decomp_op (m_decomposed )
241
242
)
242
243
243
244
# checking the map still has the same ids, the node may change
@@ -249,7 +250,6 @@ def test_prepare_for_propagation_comparison(self):
249
250
m = TestHelperModules .Conv2dThenConv1d ()
250
251
example_inputs = m .example_inputs ()
251
252
ep = export_for_training (m , example_inputs , strict = True )
252
- generate_numeric_debug_handle (ep )
253
253
m = ep .module ()
254
254
m_logger = prepare_for_propagation_comparison (m )
255
255
ref = m (* example_inputs )
@@ -266,7 +266,6 @@ def test_extract_results_from_loggers(self):
266
266
m = TestHelperModules .Conv2dThenConv1d ()
267
267
example_inputs = m .example_inputs ()
268
268
ep = export_for_training (m , example_inputs , strict = True )
269
- generate_numeric_debug_handle (ep )
270
269
m = ep .module ()
271
270
m_ref_logger = prepare_for_propagation_comparison (m )
272
271
@@ -291,7 +290,6 @@ def test_extract_results_from_loggers_list_output(self):
291
290
m = TestHelperModules .Conv2dWithSplit ()
292
291
example_inputs = m .example_inputs ()
293
292
ep = export_for_training (m , example_inputs , strict = True )
294
- generate_numeric_debug_handle (ep )
295
293
m = ep .module ()
296
294
m_ref_logger = prepare_for_propagation_comparison (m )
297
295
@@ -321,9 +319,10 @@ def test_added_node_gets_unique_id(self) -> None:
321
319
m = TestHelperModules .Conv2dThenConv1d ()
322
320
example_inputs = m .example_inputs ()
323
321
ep = export_for_training (m , example_inputs , strict = True )
324
- generate_numeric_debug_handle ( ep )
325
- ref_handles = self ._extract_debug_handles (ep )
322
+
323
+ ref_handles = self ._extract_debug_handles (ep . module () )
326
324
ref_counter = Counter (ref_handles .values ())
325
+
327
326
for k , v in ref_counter .items ():
328
327
self .assertEqual (
329
328
v ,
@@ -345,10 +344,10 @@ def test_added_node_gets_unique_id(self) -> None:
345
344
346
345
# Regenerate handles, make sure only the new relu node has a new id, and
347
346
# it doesn't clash with any of the existing ids.
348
- generate_numeric_debug_handle (ep )
349
347
350
- self ._assert_each_node_has_debug_handle (ep )
351
- handles_after_modification = self ._extract_debug_handles (ep )
348
+ m = ep .module ()
349
+ self ._assert_each_node_has_debug_handle (m )
350
+ handles_after_modification = self ._extract_debug_handles (m )
352
351
handles_counter = Counter (handles_after_modification .values ())
353
352
for name , handle in ref_handles .items ():
354
353
self .assertIn (name , handles_after_modification )
@@ -365,7 +364,7 @@ def test_added_node_gets_unique_id(self) -> None:
365
364
366
365
# Check for relu specifically. Avoid hardcoding the handle id since it
367
366
# may change with future node ordering changes.
368
- self .assertNotEqual (handles_after_modification ["relu_default" ], 0 )
367
+ self .assertNotIn (handles_after_modification ["relu_default" ], ref_counter )
369
368
self .assertEqual (handles_counter [handles_after_modification ["relu_default" ]], 1 )
370
369
371
370
0 commit comments