Gather intermediate output from Events #11454

Merged (1 commit) on Jun 10, 2025
31 changes: 31 additions & 0 deletions devtools/inspector/_inspector.py
@@ -46,6 +46,7 @@
    display_or_print_df,
    EDGE_DIALECT_GRAPH_KEY,
    EXCLUDED_COLUMNS_WHEN_PRINTING,
    EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT,
    EXCLUDED_EVENTS_WHEN_PRINTING,
    find_populated_event,
    FORWARD,
@@ -1149,6 +1150,36 @@ def _consume_etrecord(self) -> None:
                self._etrecord._representative_inputs
            )

    # TODO: Make it more extensible to further merge overlapping debug handles
    def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
        """
        Retrieve the raw runtime intermediate outputs (mappings from debug
        handles to values) from the event blocks. These outputs will be
        processed later to merge overlapping debug handles.
        """
        debug_handle_to_output = {}
        for event_block in self.event_blocks:
            for event in event_block.events:
                # Skip OPERATOR_CALL events to avoid double-counting and to exclude framework tax
                if (
                    event.name in EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT
                    or not event.op_types
                ):
                    continue
                # Normalize debug_handles to a tuple
                debug_handles = event.debug_handles
                if isinstance(debug_handles, int):
                    debug_handles = (debug_handles,)
                else:
                    debug_handles = tuple(debug_handles)
                current_entry = debug_handle_to_output.get(debug_handles, (-1, None))
                # When events share the same debug handles, keep only the one with the largest instruction id
                if event._instruction_id > current_entry[0]:
                    debug_handle_to_output[debug_handles] = (
                        event._instruction_id,
                        event.debug_data,
                    )
        return {k: v[1] for k, v in debug_handle_to_output.items()}

    def to_dataframe(
        self,
        include_units: bool = True,
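To make the gathering logic above concrete, here is a minimal standalone sketch of the same pattern. Events are modeled as plain tuples rather than the Inspector's Event objects, and the gather_intermediate_outputs name and tuple layout are illustrative assumptions, not part of this PR:

from typing import Any, Dict, List, Tuple, Union

EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT = {"OPERATOR_CALL"}

# Each event is modeled as (name, op_types, debug_handles, instruction_id, debug_data)
EventTuple = Tuple[str, List[str], Union[int, List[int]], int, Any]


def gather_intermediate_outputs(events: List[EventTuple]) -> Dict[Tuple[int, ...], Any]:
    debug_handle_to_output: Dict[Tuple[int, ...], Tuple[int, Any]] = {}
    for name, op_types, debug_handles, instruction_id, debug_data in events:
        # Skip excluded events and events without op types
        if name in EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT or not op_types:
            continue
        # Normalize debug_handles to a tuple so it can serve as a dict key
        handles = (debug_handles,) if isinstance(debug_handles, int) else tuple(debug_handles)
        # Per handle tuple, keep only the entry with the largest instruction id
        prev_id, _ = debug_handle_to_output.get(handles, (-1, None))
        if instruction_id > prev_id:
            debug_handle_to_output[handles] = (instruction_id, debug_data)
    return {k: v[1] for k, v in debug_handle_to_output.items()}


# Mirrors the dedup case exercised by the unit test below: two events share
# debug handle 4, and the one with the larger instruction id wins.
events = [
    ("OPERATOR_CALL", ["aten::add"], 0, 0, "framework tax, skipped"),
    ("op_2", ["aten::add"], 4, 4, [1.0, 2.0, 3.0]),
    ("op_3", ["aten::add"], 4, 5, [4.0, 5.0, 6.0]),
]
print(gather_intermediate_outputs(events))  # {(4,): [4.0, 5.0, 6.0]}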
2 changes: 2 additions & 0 deletions devtools/inspector/_inspector_utils.py
@@ -52,6 +52,8 @@
]
EXCLUDED_EVENTS_WHEN_PRINTING = {"OPERATOR_CALL"}

EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT = {"OPERATOR_CALL"}


class TimeScale(Enum):
    NS = "ns"
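Both exclusion sets currently contain the same single event name, but they gate independent code paths (dataframe printing versus intermediate-output gathering) and are free to diverge later. A small illustration of the shared filtering pattern, using assumed event names:

EXCLUDED_EVENTS_WHEN_PRINTING = {"OPERATOR_CALL"}
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT = {"OPERATOR_CALL"}

event_names = ["OPERATOR_CALL", "native_call_add.out", "native_call_mul.out"]

# The same membership test is applied in both paths, but against separate sets
printable = [n for n in event_names if n not in EXCLUDED_EVENTS_WHEN_PRINTING]
gatherable = [n for n in event_names if n not in EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT]
print(printable)   # ['native_call_add.out', 'native_call_mul.out']
print(gatherable)  # ['native_call_add.out', 'native_call_mul.out']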
105 changes: 102 additions & 3 deletions devtools/inspector/tests/inspector_test.py
@@ -13,7 +13,7 @@
import unittest
from contextlib import redirect_stdout

from typing import Callable, List, Union

from unittest.mock import patch

@@ -56,7 +56,7 @@

OP_TYPE = "aten::add"
EVENT_BLOCK_NAME = "block_0"
EVENTS_SIZE = 10
RAW_DATA_SIZE = 10
ETDUMP_PATH = "unittest_etdump_path"
ETRECORD_PATH = "unittest_etrecord_path"
@@ -535,17 +535,116 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
            )
        )

    def test_get_runtime_intermediate_outputs(self):
        # Patch the functions called by Inspector.__init__ so no real
        # ETDump/ETRecord files are needed
        with patch.object(
            _inspector, "parse_etrecord", return_value=None
        ), patch.object(
            _inspector, "gen_etdump_object", return_value=None
        ), patch.object(
            EventBlock, "_gen_from_etdump"
        ), patch.object(
            _inspector, "gen_graphs_from_etrecord"
        ):
            # Call the constructor of Inspector
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # The mocked Inspector instance starts with an empty event block list.
            # Add pre-defined event blocks to test _get_runtime_intermediate_outputs().
            inspector_instance.event_blocks = [
                EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events())
            ]

            runtime_outputs = inspector_instance._get_runtime_intermediate_outputs()
            # The output should be a dictionary with 5 keys
            self.assertEqual(
                len(runtime_outputs),
                5,
            )
            # Keys (0,) and (1,) must be absent: (0,) belongs to an OPERATOR_CALL
            # event and (1,) to an event with empty op_types
            self.assertNotIn((0,), runtime_outputs)
            self.assertNotIn((1,), runtime_outputs)

            # Two events share debug handle 4 but differ in instruction id; only
            # the one with the largest instruction id should be recorded
            self.assertIn((4,), runtime_outputs)
            self.assertTrue(
                torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
            )
            # Keys (5,) to (8,) should be present with values of the expected size
            for key in range(5, 9):
                self.assertIn((key,), runtime_outputs)
                self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)

    def _gen_random_float_list(self) -> List[float]:
        return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]

    def _gen_random_runtime_output(
        self,
    ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
        return list(torch.randn(RAW_DATA_SIZE))

    def _gen_random_events(self) -> List[Event]:
        events = []
        for i in range(2):
            events.append(
                # OPERATOR_CALL with debug_handles/instruction_id 0 and 2
                Event(
                    name="OPERATOR_CALL",
                    op_types=[OP_TYPE],
                    perf_data=PerfData(self._gen_random_float_list()),
                    debug_handles=i * 2,
                    _instruction_id=i * 2,
                    debug_data=self._gen_random_runtime_output(),
                )
            )
            events.append(
                # op_0/op_1 with empty op_types and with debug_handles/instruction_id 1 and 3
                Event(
                    name=f"op_{i}",
                    op_types=[],
                    perf_data=PerfData(self._gen_random_float_list()),
                    debug_handles=i * 2 + 1,
                    _instruction_id=i * 2 + 1,
                    debug_data=self._gen_random_runtime_output(),
                )
            )

        # op_2 with debug_handles/instruction_id 4
        events.append(
            Event(
                name="op_2",
                op_types=[OP_TYPE],
                perf_data=PerfData(self._gen_random_float_list()),
                debug_handles=4,
                debug_data=[torch.tensor([1.0, 2.0, 3.0])],
                _instruction_id=4,
            )
        )
        # op_3 also with debug_handles 4 but with instruction_id 5
        events.append(
            Event(
                name="op_3",
                op_types=[OP_TYPE],
                perf_data=PerfData(self._gen_random_float_list()),
                debug_handles=4,
                debug_data=[torch.tensor([4.0, 5.0, 6.0])],
                _instruction_id=5,
            )
        )

        # op_4 to op_7 with debug_handles 5 to 8 and instruction_id 6 to 9
        for i in range(4, EVENTS_SIZE - 2):
            events.append(
                Event(
                    name=f"op_{i}",
                    op_types=[OP_TYPE],
                    perf_data=PerfData(self._gen_random_float_list()),
                    debug_handles=i + 1,
                    debug_data=self._gen_random_runtime_output(),
                    _instruction_id=i + 2,
                )
            )
        return events
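The TODO on _get_runtime_intermediate_outputs notes that overlapping debug handles still need to be merged. As one hypothetical direction, not part of this PR, handle tuples that intersect could be unioned into a single key. The merge_overlapping_debug_handles name below is an assumption, and which output should survive a merge is an open design question; this sketch simply keeps the most recently processed one:

from typing import Any, Dict, Tuple


def merge_overlapping_debug_handles(
    handle_to_output: Dict[Tuple[int, ...], Any]
) -> Dict[Tuple[int, ...], Any]:
    merged: Dict[Tuple[int, ...], Any] = {}
    for handles, output in handle_to_output.items():
        union = set(handles)
        # Absorb every existing key that shares at least one handle; keys in
        # merged stay pairwise disjoint, so a single pass suffices
        for key in [k for k in merged if union & set(k)]:
            union |= set(key)
            merged.pop(key)
        merged[tuple(sorted(union))] = output
    return merged


print(merge_overlapping_debug_handles({(1, 2): "a", (2, 3): "b", (5,): "c"}))
# {(1, 2, 3): 'b', (5,): 'c'}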