Skip to content

Commit 3f11883

Browse files
authored
Gather intermediate output from Events
Differential Revision: D76105355 Pull Request resolved: #11454
1 parent 00ca8ff commit 3f11883

File tree

3 files changed

+135
-3
lines changed

3 files changed

+135
-3
lines changed

devtools/inspector/_inspector.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
display_or_print_df,
4747
EDGE_DIALECT_GRAPH_KEY,
4848
EXCLUDED_COLUMNS_WHEN_PRINTING,
49+
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT,
4950
EXCLUDED_EVENTS_WHEN_PRINTING,
5051
find_populated_event,
5152
FORWARD,
@@ -1149,6 +1150,36 @@ def _consume_etrecord(self) -> None:
11491150
self._etrecord._representative_inputs
11501151
)
11511152

1153+
# TODO: Make it more extensible to further merge overlapping debug handles
1154+
def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
1155+
"""
1156+
Retrieve the raw runtime intermediate outputs(debug handles and value mappings)
1157+
from the event blocks. These outputs will be processed later to merge overlapping debug handles.
1158+
"""
1159+
debug_handle_to_output = {}
1160+
for event_block in self.event_blocks:
1161+
for event in event_block.events:
1162+
# Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax
1163+
if (
1164+
event.name in EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT
1165+
or not event.op_types
1166+
):
1167+
continue
1168+
# Normalize debug_handles to a tuple
1169+
debug_handles = event.debug_handles
1170+
if isinstance(debug_handles, int):
1171+
debug_handles = (debug_handles,)
1172+
else:
1173+
debug_handles = tuple(debug_handles)
1174+
current_entry = debug_handle_to_output.get(debug_handles, (-1, None))
1175+
# When event has same debug handles, only keep the one with the largest instruction id
1176+
if event._instruction_id > current_entry[0]:
1177+
debug_handle_to_output[debug_handles] = (
1178+
event._instruction_id,
1179+
event.debug_data,
1180+
)
1181+
return {k: v[1] for k, v in debug_handle_to_output.items()}
1182+
11521183
def to_dataframe(
11531184
self,
11541185
include_units: bool = True,

devtools/inspector/_inspector_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
]
5353
EXCLUDED_EVENTS_WHEN_PRINTING = {"OPERATOR_CALL"}
5454

55+
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT = {"OPERATOR_CALL"}
56+
5557

5658
class TimeScale(Enum):
5759
NS = "ns"

devtools/inspector/tests/inspector_test.py

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import unittest
1414
from contextlib import redirect_stdout
1515

16-
from typing import Callable, List
16+
from typing import Callable, List, Union
1717

1818
from unittest.mock import patch
1919

@@ -56,7 +56,7 @@
5656

5757
OP_TYPE = "aten::add"
5858
EVENT_BLOCK_NAME = "block_0"
59-
EVENTS_SIZE = 5
59+
EVENTS_SIZE = 10
6060
RAW_DATA_SIZE = 10
6161
ETDUMP_PATH = "unittest_etdump_path"
6262
ETRECORD_PATH = "unittest_etrecord_path"
@@ -535,17 +535,116 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
535535
)
536536
)
537537

538+
def test_get_runtime_intermediate_outputs(self):
539+
# Create a context manager to patch functions called by Inspector.__init__
540+
with patch.object(
541+
_inspector, "parse_etrecord", return_value=None
542+
), patch.object(
543+
_inspector, "gen_etdump_object", return_value=None
544+
), patch.object(
545+
EventBlock, "_gen_from_etdump"
546+
), patch.object(
547+
_inspector, "gen_graphs_from_etrecord"
548+
):
549+
# Call the constructor of Inspector
550+
inspector_instance = Inspector(
551+
etdump_path=ETDUMP_PATH,
552+
etrecord=ETRECORD_PATH,
553+
)
554+
555+
# The mock inspector instance starts with having an empty event blocks list.
556+
# Add pre-defined event blocks to test _get_runtime_outputs().
557+
inspector_instance.event_blocks = [
558+
EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events())
559+
]
560+
561+
runtime_outputs = inspector_instance._get_runtime_intermediate_outputs()
562+
# This output should be a dictionary with 5 keys
563+
self.assertEqual(
564+
len(runtime_outputs),
565+
5,
566+
)
567+
# Check that keys (0,) and (1,) are not in the dictionary(skip OPERATOR_CALL and op_types are empty)
568+
self.assertNotIn((0,), runtime_outputs)
569+
self.assertNotIn((1,), runtime_outputs)
570+
571+
# Same debug_handle but different instruction_id, should record the last one
572+
self.assertIn((4,), runtime_outputs)
573+
self.assertTrue(
574+
torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
575+
)
576+
# Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
577+
for key in range(5, 9):
578+
self.assertIn((key,), runtime_outputs)
579+
self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
580+
538581
def _gen_random_float_list(self) -> List[float]:
539582
return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]
540583

584+
def _gen_random_runtime_output(
585+
self,
586+
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
587+
return list(torch.randn(RAW_DATA_SIZE))
588+
541589
def _gen_random_events(self) -> List[Event]:
542590
events = []
543-
for i in range(EVENTS_SIZE):
591+
for i in range(2):
592+
events.append(
593+
# OPERATOR_CALL with debug_hanldes/instruction_id 0 and 2
594+
Event(
595+
name="OPERATOR_CALL",
596+
op_types=[OP_TYPE],
597+
perf_data=PerfData(self._gen_random_float_list()),
598+
debug_handles=i * 2,
599+
_instruction_id=i * 2,
600+
debug_data=self._gen_random_runtime_output(),
601+
)
602+
)
603+
events.append(
604+
# op_0/op_1 wiht empty op_types and with debug_hanldes/instruction_id 1 and 3
605+
Event(
606+
name=f"op_{i}",
607+
op_types=[],
608+
perf_data=PerfData(self._gen_random_float_list()),
609+
debug_handles=i * 2 + 1,
610+
_instruction_id=i * 2 + 1,
611+
debug_data=self._gen_random_runtime_output(),
612+
)
613+
)
614+
615+
# op_2 with debug_hanldes/instruction_id 4
616+
events.append(
617+
Event(
618+
name="op_2",
619+
op_types=[OP_TYPE],
620+
perf_data=PerfData(self._gen_random_float_list()),
621+
debug_handles=4,
622+
debug_data=[torch.tensor([1.0, 2.0, 3.0])],
623+
_instruction_id=4,
624+
)
625+
)
626+
# op_3 also with debug_hanldes 4 but with instruction_id 5
627+
events.append(
628+
Event(
629+
name="op_3",
630+
op_types=[OP_TYPE],
631+
perf_data=PerfData(self._gen_random_float_list()),
632+
debug_handles=4,
633+
debug_data=[torch.tensor([4.0, 5.0, 6.0])],
634+
_instruction_id=5,
635+
)
636+
)
637+
638+
# op_4 to op_7 with debug_hanldes 5 to 8 and instruction_id 6 to 9
639+
for i in range(4, EVENTS_SIZE - 2):
544640
events.append(
545641
Event(
546642
name=f"op_{i}",
547643
op_types=[OP_TYPE],
548644
perf_data=PerfData(self._gen_random_float_list()),
645+
debug_handles=i + 1,
646+
debug_data=self._gen_random_runtime_output(),
647+
_instruction_id=i + 2,
549648
)
550649
)
551650
return events

0 commit comments

Comments
 (0)