Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 143 additions & 42 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import numpy as np
import pandas as pd

import torch

from executorch.devtools.debug_format.et_schema import OperatorGraph, OperatorNode
from executorch.devtools.etdump.schema_flatcc import (
DebugEvent,
Expand Down Expand Up @@ -1164,33 +1166,35 @@ def _consume_etrecord(self) -> None:
index
]

def _get_aot_intermediate_outputs_and_op_names(
def _resolve_reference_graph(
self,
reference_graph: Optional[str] = None,
disable_debug_handle_valdiation: bool = False,
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
disable_debug_handle_validation: bool = False,
) -> Tuple[torch.fx.GraphModule, str]:
"""
Capture intermediate outputs only if _representative_inputs are provided
when using bundled program to create the etrecord.
Resolve the reference graph module to use for AOT operations.

This method centralizes the logic for determining which graph module to use,
ensuring consistency across all methods that need the reference graph.

Args:
reference_graph_name: Name of the graph to use as the reference for intermediate
output capture. Must be one of:
- "exported_program": Uses the ATen dialect exported program. Requires
successful debug handle backpropagation, otherwise raises an error.
- "edge_dialect_exported_program": Uses the Edge dialect program directly.
- Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward"
for post-custom-transform graph when transform_passes are applied in
to_edge_transform_and_lower with generate_etrecord=True).
disable_debug_handle_valdiation: If True, skip debug handle validation.
reference_graph: The name of the reference graph. Options:
- None: Auto-select (try exported_program first, fall back to edge_dialect)
- "exported_program": Use ATen dialect with debug handle backpropagation
- "edge_dialect_exported_program": Use Edge dialect directly
- Any other string: Look up in graph_map
disable_debug_handle_validation: If True, skip debug handle validation for
exported_program.

Returns:
Tuple of (intermediate_outputs, debug_handle_to_op_names) dictionaries.
Tuple of (graph_module, resolved_graph_name) where resolved_graph_name is
the actual graph used.

Raises:
ValueError: If the specified reference_graph_name is not available or if
ValueError: If the specified reference_graph is not available or if
debug handle backpropagation fails for "exported_program".
"""
resolved_graph_name = reference_graph

# Determine the reference graph to use
if reference_graph is None or reference_graph == "exported_program":
Expand All @@ -1199,39 +1203,36 @@ def _get_aot_intermediate_outputs_and_op_names(
self._etrecord.exported_program,
self._etrecord.export_graph_id,
self._etrecord.edge_dialect_program,
disable_debug_handle_valdiation,
disable_debug_handle_validation,
):
reference_graph = "exported_program"
resolved_graph_name = "exported_program"
elif reference_graph is None:
log.warning(
"Either ATen dialect exported program is not in ETRecord, or debug handle "
"backpropagation failed. Falling back to 'edge_dialect_exported_program'."
)
reference_graph = "edge_dialect_exported_program"
resolved_graph_name = "edge_dialect_exported_program"
else:
raise ValueError(
"Cannot use 'exported_program': Debug handle backpropagation failed or exported program is unavailable. "
"Please check if the exported program is available in ETRecord, or try to disable debug handle validation."
)
if reference_graph == "edge_dialect_exported_program":
# Explicitly requested edge_dialect_exported_program

if resolved_graph_name == "edge_dialect_exported_program":
export_program = self._etrecord.edge_dialect_program
log.info(
"Using 'edge_dialect_exported_program' (Edge dialect) as reference graph for intermediate output capture"
"Using 'edge_dialect_exported_program' (Edge dialect) as reference graph"
)
elif reference_graph == "exported_program":
elif resolved_graph_name == "exported_program":
export_program = self._etrecord.exported_program
log.info(
"Using 'exported_program' (ATen dialect) as reference graph for intermediate output capture"
)
log.info("Using 'exported_program' (ATen dialect) as reference graph")
else:
# Try to fetch from graph_map
# If no method name is provided (no "/" in the name), try adding "/forward" as default
lookup_name = reference_graph
if "/" not in reference_graph:
lookup_name = f"{reference_graph}/forward"
lookup_name = resolved_graph_name
if "/" not in resolved_graph_name:
lookup_name = f"{resolved_graph_name}/forward"
log.info(
f"No method name specified in '{reference_graph}', "
f"No method name specified in '{resolved_graph_name}', "
f"using '{lookup_name}' as default"
)

Expand All @@ -1240,9 +1241,7 @@ def _get_aot_intermediate_outputs_and_op_names(
and lookup_name in self._etrecord.graph_map
):
export_program = self._etrecord.graph_map[lookup_name]
log.info(
f"Using '{lookup_name}' from graph_map as reference graph for intermediate output capture"
)
log.info(f"Using '{lookup_name}' from graph_map as reference graph")
else:
available_graphs = (
list(self._etrecord.graph_map.keys())
Expand All @@ -1254,16 +1253,82 @@ def _get_aot_intermediate_outputs_and_op_names(
f"Available options: 'exported_program', 'edge_dialect_exported_program', "
f"or one of the graphs in graph_map: {available_graphs}"
)
graph_module = export_program.module()

return export_program.module(), resolved_graph_name

def _get_aot_intermediate_outputs_and_op_names(
self,
reference_graph_module: torch.fx.GraphModule,
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
"""
Capture intermediate outputs and operator name mappings from the given graph module.

Args:
reference_graph_module: The resolved reference graph module to use.

Returns:
Tuple of (intermediate_outputs, debug_handle_to_op_names) dictionaries.
"""
aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(
graph_module
reference_graph_module
)
capturer = IntermediateOutputCapturer(graph_module)
capturer = IntermediateOutputCapturer(reference_graph_module)
aot_intermediate_outputs = capturer.run_and_capture(
self._etrecord._representative_inputs
)
return aot_intermediate_outputs, aot_debug_handle_to_op_name

def _get_aot_debug_handle_to_stack_traces(
self,
reference_graph_module: torch.fx.GraphModule,
resolved_graph_name: str,
) -> Dict[DebugHandle, Dict[str, Optional[str]]]:
"""
Get a mapping from debug handle to stack traces from the given graph module.

Args:
reference_graph_module: The resolved reference graph module to use.
resolved_graph_name: The name of the graph (for warning messages).

Returns:
Dict[DebugHandle, Dict[str, Optional[str]]]: A dictionary mapping debug handles
to dictionaries of {op_name: stack_trace}.
"""
from executorch.devtools.inspector._inspector_utils import NodeFilter

node_filters = [
NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
]

result: Dict[DebugHandle, Dict[str, Optional[str]]] = {}
has_any_stack_trace = False

for node in reference_graph_module.graph.nodes:
if all(filter.matches(node) for filter in node_filters):
debug_handle = node.meta["debug_handle"]
key = (
(debug_handle,)
if isinstance(debug_handle, int)
else tuple(debug_handle)
)
stack_trace = node.meta.get("stack_trace")
if stack_trace is not None:
has_any_stack_trace = True

if key in result:
result[key][node.name] = stack_trace
else:
result[key] = {node.name: stack_trace}

if not has_any_stack_trace and result:
log.warning(
f"No stack traces found in reference_graph '{resolved_graph_name}'. "
"The 'stacktraces' column will contain None values. "
"Ensure the model was exported with stack trace information preserved."
)

return result

# TODO: Make it more extensible to further merge overlapping debug handles
def _get_runtime_intermediate_outputs_and_op_names(
self,
Expand Down Expand Up @@ -1506,17 +1571,29 @@ def calculate_numeric_gap(

Returns:
pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps.
The DataFrame includes a "stacktraces" column where each entry is a dict mapping operator names to their stack traces.
"""
# First, resolve the reference graph to use
reference_graph_module, resolved_graph_name = self._resolve_reference_graph(
reference_graph,
disable_debug_handle_valdiation,
)

# Get intermediate outputs and op names from the resolved graph
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
self._get_aot_intermediate_outputs_and_op_names(
reference_graph,
disable_debug_handle_valdiation,
)
self._get_aot_intermediate_outputs_and_op_names(reference_graph_module)
)
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
raise ValueError(
"Missing etrecord or missing representative inputs within etrecord, both of which are required for calculating numerical gap"
)

# Get the stack trace mapping from the resolved graph
aot_debug_handle_to_stack_traces = self._get_aot_debug_handle_to_stack_traces(
reference_graph_module,
resolved_graph_name,
)

# The runtime_op_names will be used later to map runtime debug_handle to op_name
runtime_intermediate_outputs, runtime_debug_handle_to_op_names = (
self._get_runtime_intermediate_outputs_and_op_names()
Expand All @@ -1543,8 +1620,32 @@ def calculate_numeric_gap(
raise ValueError(f"Unsupported distance metric {distance!r}")

# Delegate to comparator's compare method (includes preprocessing)
return comparator.compare(
df = comparator.compare(
mapping,
aot_debug_handle_to_op_names,
runtime_debug_handle_to_op_names,
)

# Add stacktraces column by looking up each row's debug handle
# We need to map from aot_ops back to debug handles to get stack traces
def get_stacktraces_for_row(aot_ops: List[str]) -> Dict[str, Optional[str]]:
"""Find stack traces for the given aot_ops by looking up in all debug handles."""
result: Dict[str, Optional[str]] = {}
for op_name in aot_ops:
# Search through debug handle mappings to find the stack trace for this op
for (
_,
stack_traces_dict,
) in aot_debug_handle_to_stack_traces.items():
if op_name in stack_traces_dict:
result[op_name] = stack_traces_dict[op_name]
break
else:
# Op not found in any debug handle's stack traces
result[op_name] = None
return result

if len(df) > 0:
df["stacktraces"] = df["aot_ops"].apply(get_stacktraces_for_row)

return df
Loading
Loading