Skip to content

Replace debug handle with from_node to trace operator transformation #2339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 38 additions & 37 deletions test/quantization/pt2e/test_numeric_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

from torchao.quantization.pt2e import (
generate_numeric_debug_handle,
prepare_for_propagation_comparison,
)
from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase
Expand All @@ -35,34 +34,35 @@ def test_simple(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
self._assert_each_node_has_debug_handle(ep)
debug_handle_map = self._extract_debug_handles(ep)
m = ep.module()
self._assert_each_node_has_debug_handle(m)
debug_handle_map = self._extract_debug_handles(m)

self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

@unittest.skip("debug flow not working on model with conditional control flow")
def test_control_flow(self):
m = TestHelperModules.ControlFlow()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()

self._assert_each_node_has_debug_handle(ep)
debug_handle_map = self._extract_debug_handles(ep)
self._assert_each_node_has_debug_handle(m)
debug_handle_map = self._extract_debug_handles(m)

self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

def test_copy_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = torch.export.export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()

self._assert_each_node_has_debug_handle(ep)
debug_handle_map_ref = self._extract_debug_handles(ep)
self._assert_each_node_has_debug_handle(m)
debug_handle_map_ref = self._extract_debug_handles(m)

ep_copy = copy.copy(ep)
debug_handle_map = self._extract_debug_handles(ep_copy)
debug_handle_map = self._extract_debug_handles(ep_copy.module())

self._assert_each_node_has_debug_handle(ep)
self.assertEqual(debug_handle_map, debug_handle_map_ref)
Expand All @@ -71,13 +71,12 @@ def test_deepcopy_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = torch.export.export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)

debug_handle_map_ref = self._extract_debug_handles(ep)
debug_handle_map_ref = self._extract_debug_handles(ep.module())
ep_copy = copy.deepcopy(ep)
debug_handle_map = self._extract_debug_handles(ep_copy)
debug_handle_map = self._extract_debug_handles(ep_copy.module())

self._assert_each_node_has_debug_handle(ep)
self._assert_each_node_has_debug_handle(ep.module())
self.assertEqual(debug_handle_map, debug_handle_map_ref)

@unittest.skip(
Expand All @@ -87,16 +86,16 @@ def test_re_export_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()

self._assert_each_node_has_debug_handle(ep)
debug_handle_map_ref = self._extract_debug_handles(ep)
self._assert_each_node_has_debug_handle(m)
debug_handle_map_ref = self._extract_debug_handles(m)

ep_reexport = export_for_training(m, example_inputs, strict=True)
m_reexport = ep_reexport.module()

self._assert_each_node_has_debug_handle(ep_reexport)
debug_handle_map = self._extract_debug_handles(ep_reexport)
self._assert_each_node_has_debug_handle(m_reexport)
debug_handle_map = self._extract_debug_handles(m_reexport)

self.assertEqual(debug_handle_map, debug_handle_map_ref)

Expand All @@ -107,16 +106,17 @@ def test_run_decompositions_same_handle_id(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()

self._assert_each_node_has_debug_handle(ep)
debug_handle_map_ref = self._extract_debug_handles(ep)
self._assert_each_node_has_debug_handle(m)
debug_handle_map_ref = self._extract_debug_handles(m)

ep_copy = copy.copy(ep)
ep_copy = ep_copy.run_decompositions()
m_decomposed = ep_copy.module()

self._assert_each_node_has_debug_handle(ep_copy)
debug_handle_map = self._extract_debug_handles(ep_copy)
self._assert_each_node_has_debug_handle(m_decomposed)
debug_handle_map = self._extract_debug_handles(m_decomposed)

# checking the map still has the same ids, the node may change
self.assertEqual(
Expand All @@ -135,18 +135,19 @@ def test_run_decompositions_map_handle_to_new_nodes(self):
for m in test_models:
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()

self._assert_each_node_has_debug_handle(ep)
self._assert_each_node_has_debug_handle(m)
pre_decomp_to_debug_handle_map_ref = (
self._extract_debug_handles_with_prev_decomp_op(ep)
self._extract_debug_handles_with_prev_decomp_op(m)
)

ep_copy = copy.copy(ep)
ep_copy = ep_copy.run_decompositions()
self._assert_each_node_has_debug_handle(ep_copy)
m_decomposed = ep_copy.module()
self._assert_each_node_has_debug_handle(m_decomposed)
pre_decomp_to_debug_handle_map = (
self._extract_debug_handles_with_prev_decomp_op(ep_copy)
self._extract_debug_handles_with_prev_decomp_op(m_decomposed)
)

# checking the map still has the same ids, the node may change
Expand All @@ -158,7 +159,6 @@ def test_prepare_for_propagation_comparison(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
m_logger = prepare_for_propagation_comparison(m)
ref = m(*example_inputs)
Expand All @@ -175,9 +175,10 @@ def test_added_node_gets_unique_id(self) -> None:
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
ref_handles = self._extract_debug_handles(ep)

ref_handles = self._extract_debug_handles(ep.module())
ref_counter = Counter(ref_handles.values())

for k, v in ref_counter.items():
self.assertEqual(
v,
Expand All @@ -199,10 +200,10 @@ def test_added_node_gets_unique_id(self) -> None:

# Regenerate handles, make sure only the new relu node has a new id, and
# it doesn't clash with any of the existing ids.
generate_numeric_debug_handle(ep)

self._assert_each_node_has_debug_handle(ep)
handles_after_modification = self._extract_debug_handles(ep)
m = ep.module()
self._assert_each_node_has_debug_handle(m)
handles_after_modification = self._extract_debug_handles(m)
handles_counter = Counter(handles_after_modification.values())
for name, handle in ref_handles.items():
self.assertIn(name, handles_after_modification)
Expand All @@ -219,7 +220,7 @@ def test_added_node_gets_unique_id(self) -> None:

# Check for relu specifically. Avoid hardcoding the handle id since it
# may change with future node ordering changes.
self.assertNotEqual(handles_after_modification["relu_default"], 0)
self.assertNotIn(handles_after_modification["relu_default"], ref_counter)
self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)


Expand Down
5 changes: 3 additions & 2 deletions test/quantization/pt2e/test_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)

import torchao
from torchao.quantization.pt2e import ObserverOrFakeQuantize, observer
from torchao.quantization.pt2e import FROM_NODE_KEY, ObserverOrFakeQuantize, observer
from torchao.quantization.pt2e.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
Expand Down Expand Up @@ -1499,7 +1499,8 @@ def forward(self, x):
for n in m.graph.nodes:
if n.op == "get_attr" and "frozen_param" in n.target:
for key in n.meta:
self.assertEqual(n.meta[key], weight_meta[key])
if key != FROM_NODE_KEY:
self.assertEqual(n.meta[key], weight_meta[key])

def test_save_load(self):
"""Test save/load a quantized model"""
Expand Down
2 changes: 2 additions & 0 deletions torchao/quantization/pt2e/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from torchao.quantization.pt2e._numeric_debugger import ( # noqa: F401
CUSTOM_KEY,
FROM_NODE_KEY,
NUMERIC_DEBUG_HANDLE_KEY,
compare_results,
extract_results_from_loggers,
Expand Down Expand Up @@ -132,6 +133,7 @@
"generate_numeric_debug_handle",
"CUSTOM_KEY",
"NUMERIC_DEBUG_HANDLE_KEY",
"FROM_NODE_KEY",
"prepare_for_propagation_comparison",
"extract_results_from_loggers",
"compare_results",
Expand Down
75 changes: 66 additions & 9 deletions torchao/quantization/pt2e/_numeric_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@
from torch.fx import GraphModule, Node
from torch.nn import functional as F

from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

if TORCH_VERSION_AT_LEAST_2_6:
from torch.fx.traceback import NodeSource

from .graph_utils import bfs_trace_with_node_process

NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
CUSTOM_KEY = "custom"
FROM_NODE_KEY = "from_node"

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,6 +84,56 @@ def _assign_debug_handle(node: torch.fx.Node) -> None:
bfs_trace_with_node_process(ep, _assign_debug_handle)


def _get_greatest_ancestor_node_source(node: Node) -> Optional["NodeSource"]:
    """Return the oldest ancestor in ``node``'s ``from_node`` provenance chain.

    Returns ``None`` when the node carries no ``from_node`` metadata.
    """
    sources = node.meta.get(FROM_NODE_KEY)
    if sources is None:
        return None

    # Start from the most recent provenance record and keep following the
    # chain of ancestors until a record with no further ancestors is found.
    current = sources[-1]
    while current.from_node:
        current = current.from_node[-1]
    return current


def _generate_debug_handle_from_node(node: Node) -> Optional[int]:
    """
    Generate a debug handle based on node's oldest ancestor node's name
    and graph id, or return None if the node does not need to be traced.

    This is a temporary function for migrating node tracing infra from
    using debug handle to node.meta["from_node"]. The infrastructure will
    depend on node.meta["from_node"] directly in the future, without the need
    of debug handle as intermediate variable.
    """
    # placeholder and output nodes never carry a debug handle
    if node.op in ("placeholder", "output"):
        return None

    unlift_pass = "ExportedProgram.module().unlift()"

    # .get() covers both "key absent" and "value is None"; a node whose
    # newest provenance record comes from the unlift pass is not part of
    # ExportedProgram.module().graph and has no debug handle either.
    sources = node.meta.get(FROM_NODE_KEY)
    if sources is None or sources[-1].pass_name == unlift_pass:
        return None

    ancestor = _get_greatest_ancestor_node_source(node)
    if ancestor is None:
        # No provenance chain: not part of ExportedProgram.module().graph
        return None

    if ancestor.pass_name == unlift_pass:
        # unlifted nodes don't have debug handle
        return None

    # NOTE(review): hash() of a str is salted per interpreter run
    # (PYTHONHASHSEED), so these handles are only stable within one process.
    return hash(ancestor.name + str(ancestor.graph_id))


def _detach(x: object) -> object:
detached: object = None
if isinstance(x, torch.Tensor):
Expand Down Expand Up @@ -187,23 +243,24 @@ def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:


def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
"""Add output loggers to node that has numeric_debug_handle
"""Add output loggers to unlifted node

Args:
model (GraphModule): original model
Returns:
a model with output loggers for all nodes that has numeric_debug_handle_id
a model with output loggers for all unlifted nodes
"""
if not TORCH_VERSION_AT_LEAST_2_6:
log.warning(
"prepare_for_propagation_comparison is only supported for PyTorch 2.6+"
)
return model

# don't change the original model
model = copy.deepcopy(model)
for n in model.graph.nodes:
if (
CUSTOM_KEY not in n.meta
or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY]
):
continue
numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
_insert_logger(model, n, numeric_debug_handle)
if (numeric_debug_handle := _generate_debug_handle_from_node(n)) is not None:
_insert_logger(model, n, numeric_debug_handle)

model.recompile()
return model
Expand Down
Loading
Loading