15
15
from torch .testing ._internal .common_utils import IS_WINDOWS , run_tests
16
16
17
17
from torchao .quantization .pt2e import (
18
- generate_numeric_debug_handle ,
19
18
prepare_for_propagation_comparison ,
20
19
)
21
20
from torchao .testing .pt2e .utils import PT2ENumericDebuggerTestCase
@@ -35,34 +34,35 @@ def test_simple(self):
35
34
m = TestHelperModules .Conv2dThenConv1d ()
36
35
example_inputs = m .example_inputs ()
37
36
ep = export_for_training (m , example_inputs , strict = True )
38
- generate_numeric_debug_handle ( ep )
39
- self ._assert_each_node_has_debug_handle (ep )
40
- debug_handle_map = self ._extract_debug_handles (ep )
37
+ m = ep . module ( )
38
+ self ._assert_each_node_has_debug_handle (m )
39
+ debug_handle_map = self ._extract_debug_handles (m )
41
40
42
41
self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
43
42
43
+ @unittest .skip ("debug flow not working on model with conditional control flow" )
44
44
def test_control_flow (self ):
45
45
m = TestHelperModules .ControlFlow ()
46
46
example_inputs = m .example_inputs ()
47
47
ep = export_for_training (m , example_inputs , strict = True )
48
- generate_numeric_debug_handle ( ep )
48
+ m = ep . module ( )
49
49
50
- self ._assert_each_node_has_debug_handle (ep )
51
- debug_handle_map = self ._extract_debug_handles (ep )
50
+ self ._assert_each_node_has_debug_handle (m )
51
+ debug_handle_map = self ._extract_debug_handles (m )
52
52
53
53
self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
54
54
55
55
def test_copy_preserve_handle (self ):
56
56
m = TestHelperModules .Conv2dThenConv1d ()
57
57
example_inputs = m .example_inputs ()
58
58
ep = torch .export .export (m , example_inputs , strict = True )
59
- generate_numeric_debug_handle ( ep )
59
+ m = ep . module ( )
60
60
61
- self ._assert_each_node_has_debug_handle (ep )
62
- debug_handle_map_ref = self ._extract_debug_handles (ep )
61
+ self ._assert_each_node_has_debug_handle (m )
62
+ debug_handle_map_ref = self ._extract_debug_handles (m )
63
63
64
64
ep_copy = copy .copy (ep )
65
- debug_handle_map = self ._extract_debug_handles (ep_copy )
65
+ debug_handle_map = self ._extract_debug_handles (ep_copy . module () )
66
66
67
67
self ._assert_each_node_has_debug_handle (ep )
68
68
self .assertEqual (debug_handle_map , debug_handle_map_ref )
@@ -71,13 +71,12 @@ def test_deepcopy_preserve_handle(self):
71
71
m = TestHelperModules .Conv2dThenConv1d ()
72
72
example_inputs = m .example_inputs ()
73
73
ep = torch .export .export (m , example_inputs , strict = True )
74
- generate_numeric_debug_handle (ep )
75
74
76
- debug_handle_map_ref = self ._extract_debug_handles (ep )
75
+ debug_handle_map_ref = self ._extract_debug_handles (ep . module () )
77
76
ep_copy = copy .deepcopy (ep )
78
- debug_handle_map = self ._extract_debug_handles (ep_copy )
77
+ debug_handle_map = self ._extract_debug_handles (ep_copy . module () )
79
78
80
- self ._assert_each_node_has_debug_handle (ep )
79
+ self ._assert_each_node_has_debug_handle (ep . module () )
81
80
self .assertEqual (debug_handle_map , debug_handle_map_ref )
82
81
83
82
@unittest .skip (
@@ -87,16 +86,16 @@ def test_re_export_preserve_handle(self):
87
86
m = TestHelperModules .Conv2dThenConv1d ()
88
87
example_inputs = m .example_inputs ()
89
88
ep = export_for_training (m , example_inputs , strict = True )
90
- generate_numeric_debug_handle (ep )
91
89
m = ep .module ()
92
90
93
- self ._assert_each_node_has_debug_handle (ep )
94
- debug_handle_map_ref = self ._extract_debug_handles (ep )
91
+ self ._assert_each_node_has_debug_handle (m )
92
+ debug_handle_map_ref = self ._extract_debug_handles (m )
95
93
96
94
ep_reexport = export_for_training (m , example_inputs , strict = True )
95
+ m_reexport = ep_reexport .module ()
97
96
98
- self ._assert_each_node_has_debug_handle (ep_reexport )
99
- debug_handle_map = self ._extract_debug_handles (ep_reexport )
97
+ self ._assert_each_node_has_debug_handle (m_reexport )
98
+ debug_handle_map = self ._extract_debug_handles (m_reexport )
100
99
101
100
self .assertEqual (debug_handle_map , debug_handle_map_ref )
102
101
@@ -107,16 +106,17 @@ def test_run_decompositions_same_handle_id(self):
107
106
m = TestHelperModules .Conv2dThenConv1d ()
108
107
example_inputs = m .example_inputs ()
109
108
ep = export_for_training (m , example_inputs , strict = True )
110
- generate_numeric_debug_handle ( ep )
109
+ m = ep . module ( )
111
110
112
- self ._assert_each_node_has_debug_handle (ep )
113
- debug_handle_map_ref = self ._extract_debug_handles (ep )
111
+ self ._assert_each_node_has_debug_handle (m )
112
+ debug_handle_map_ref = self ._extract_debug_handles (m )
114
113
115
114
ep_copy = copy .copy (ep )
116
115
ep_copy = ep_copy .run_decompositions ()
116
+ m_decomposed = ep_copy .module ()
117
117
118
- self ._assert_each_node_has_debug_handle (ep_copy )
119
- debug_handle_map = self ._extract_debug_handles (ep_copy )
118
+ self ._assert_each_node_has_debug_handle (m_decomposed )
119
+ debug_handle_map = self ._extract_debug_handles (m_decomposed )
120
120
121
121
# checking the map still has the same ids, the node may change
122
122
self .assertEqual (
@@ -135,18 +135,19 @@ def test_run_decompositions_map_handle_to_new_nodes(self):
135
135
for m in test_models :
136
136
example_inputs = m .example_inputs ()
137
137
ep = export_for_training (m , example_inputs , strict = True )
138
- generate_numeric_debug_handle ( ep )
138
+ m = ep . module ( )
139
139
140
- self ._assert_each_node_has_debug_handle (ep )
140
+ self ._assert_each_node_has_debug_handle (m )
141
141
pre_decomp_to_debug_handle_map_ref = (
142
- self ._extract_debug_handles_with_prev_decomp_op (ep )
142
+ self ._extract_debug_handles_with_prev_decomp_op (m )
143
143
)
144
144
145
145
ep_copy = copy .copy (ep )
146
146
ep_copy = ep_copy .run_decompositions ()
147
- self ._assert_each_node_has_debug_handle (ep_copy )
147
+ m_decomposed = ep_copy .module ()
148
+ self ._assert_each_node_has_debug_handle (m_decomposed )
148
149
pre_decomp_to_debug_handle_map = (
149
- self ._extract_debug_handles_with_prev_decomp_op (ep_copy )
150
+ self ._extract_debug_handles_with_prev_decomp_op (m_decomposed )
150
151
)
151
152
152
153
# checking the map still has the same ids, the node may change
@@ -158,7 +159,6 @@ def test_prepare_for_propagation_comparison(self):
158
159
m = TestHelperModules .Conv2dThenConv1d ()
159
160
example_inputs = m .example_inputs ()
160
161
ep = export_for_training (m , example_inputs , strict = True )
161
- generate_numeric_debug_handle (ep )
162
162
m = ep .module ()
163
163
m_logger = prepare_for_propagation_comparison (m )
164
164
ref = m (* example_inputs )
@@ -175,9 +175,10 @@ def test_added_node_gets_unique_id(self) -> None:
175
175
m = TestHelperModules .Conv2dThenConv1d ()
176
176
example_inputs = m .example_inputs ()
177
177
ep = export_for_training (m , example_inputs , strict = True )
178
- generate_numeric_debug_handle ( ep )
179
- ref_handles = self ._extract_debug_handles (ep )
178
+
179
+ ref_handles = self ._extract_debug_handles (ep . module () )
180
180
ref_counter = Counter (ref_handles .values ())
181
+
181
182
for k , v in ref_counter .items ():
182
183
self .assertEqual (
183
184
v ,
@@ -199,10 +200,10 @@ def test_added_node_gets_unique_id(self) -> None:
199
200
200
201
# Regenerate handles, make sure only the new relu node has a new id, and
201
202
# it doesn't clash with any of the existing ids.
202
- generate_numeric_debug_handle (ep )
203
203
204
- self ._assert_each_node_has_debug_handle (ep )
205
- handles_after_modification = self ._extract_debug_handles (ep )
204
+ m = ep .module ()
205
+ self ._assert_each_node_has_debug_handle (m )
206
+ handles_after_modification = self ._extract_debug_handles (m )
206
207
handles_counter = Counter (handles_after_modification .values ())
207
208
for name , handle in ref_handles .items ():
208
209
self .assertIn (name , handles_after_modification )
@@ -219,7 +220,7 @@ def test_added_node_gets_unique_id(self) -> None:
219
220
220
221
# Check for relu specifically. Avoid hardcoding the handle id since it
221
222
# may change with future node ordering changes.
222
- self .assertNotEqual (handles_after_modification ["relu_default" ], 0 )
223
+ self .assertNotIn (handles_after_modification ["relu_default" ], ref_counter )
223
224
self .assertEqual (handles_counter [handles_after_modification ["relu_default" ]], 1 )
224
225
225
226
0 commit comments