Skip to content

Commit 9a3ea41

Browse files
Juntian777facebook-github-bot
authored andcommitted
Gather intermediate output from Events (#11454)
Summary: This diff gathers the debug_handles and runtime output from every "real" event, skipping those like OPERATOR_CALL and initialization events. It maps these debug_handles to their corresponding outputs and ensures that only the last one is retained for operators with the same debug_handles. Differential Revision: D76105355
1 parent c2aa614 commit 9a3ea41

File tree

3 files changed

+127
-3
lines changed

3 files changed

+127
-3
lines changed

devtools/inspector/_inspector.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
EDGE_DIALECT_GRAPH_KEY,
4848
EXCLUDED_COLUMNS_WHEN_PRINTING,
4949
EXCLUDED_EVENTS_WHEN_PRINTING,
50+
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT,
5051
find_populated_event,
5152
FORWARD,
5253
gen_etdump_object,
@@ -1149,6 +1150,33 @@ def _consume_etrecord(self) -> None:
11491150
self._etrecord._representative_inputs
11501151
)
11511152

1153+
# TODO: Make it more extensible to further merge overlapping debug handles
1154+
def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
    """
    Collect the raw runtime intermediate outputs from the event blocks.

    Every event of every block in ``self.event_blocks`` is visited. Events whose
    name is listed in EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT, or whose ``op_types``
    is empty, are ignored. For each remaining event the debug handles are
    normalized to a tuple and mapped to the event's ``debug_data``.

    When several events share the same debug handles, only the data of the event
    with the largest ``_instruction_id`` is kept. The result is meant to be
    processed later to merge overlapping debug handles.

    Returns:
        A dictionary mapping debug-handle tuples to the retained ``debug_data``.
    """
    # Maps debug-handle tuple -> (instruction id, debug data) of the current winner
    latest_by_handles = {}
    for block in self.event_blocks:
        for evt in block.events:
            # Only events that are not excluded by name and that carry op types count
            is_real_op = (
                evt.name not in EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT
                and evt.op_types
            )
            if not is_real_op:
                continue

            # A single int handle becomes a 1-tuple; any other iterable is converted as-is
            handles = evt.debug_handles
            key = (handles,) if isinstance(handles, int) else tuple(handles)

            # -1 is the sentinel for "nothing recorded for this key yet"
            recorded_id, _ = latest_by_handles.get(key, (-1, None))
            if evt._instruction_id > recorded_id:
                latest_by_handles[key] = (evt._instruction_id, evt.debug_data)

    # Drop the instruction ids; callers only need the outputs
    return {key: data for key, (_, data) in latest_by_handles.items()}
1179+
11521180
def to_dataframe(
11531181
self,
11541182
include_units: bool = True,

devtools/inspector/_inspector_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
]
5353
EXCLUDED_EVENTS_WHEN_PRINTING = {"OPERATOR_CALL"}
5454

55+
# Names of events that Inspector._get_runtime_intermediate_outputs() skips when
# mapping debug handles to runtime outputs.
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT = {"OPERATOR_CALL"}
5556

5657
class TimeScale(Enum):
5758
NS = "ns"

devtools/inspector/tests/inspector_test.py

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
# pyre-unsafe
88

99
import copy
10+
import torch
1011
import random
1112
import statistics
1213
import tempfile
1314
import unittest
1415
from contextlib import redirect_stdout
1516

16-
from typing import Callable, List
17+
from typing import Callable, List, Union
1718

1819
from unittest.mock import patch
1920

@@ -56,7 +57,7 @@
5657

5758
OP_TYPE = "aten::add"
5859
EVENT_BLOCK_NAME = "block_0"
59-
EVENTS_SIZE = 5
60+
EVENTS_SIZE = 10
6061
RAW_DATA_SIZE = 10
6162
ETDUMP_PATH = "unittest_etdump_path"
6263
ETRECORD_PATH = "unittest_etrecord_path"
@@ -535,17 +536,111 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
535536
)
536537
)
537538

539+
def test_get_runtime_intermediate_outputs(self):
    """
    Check that _get_runtime_intermediate_outputs() skips excluded events and, for
    events sharing a debug handle, keeps the output with the largest instruction id.
    """
    # Create a context manager to patch functions called by Inspector.__init__
    with patch.object(
        _inspector, "parse_etrecord", return_value=None
    ), patch.object(
        _inspector, "gen_etdump_object", return_value=None
    ), patch.object(
        EventBlock, "_gen_from_etdump"
    ), patch.object(
        _inspector, "gen_graphs_from_etrecord"
    ):
        # Call the constructor of Inspector
        inspector_instance = Inspector(
            etdump_path=ETDUMP_PATH,
            etrecord=ETRECORD_PATH,
        )

        # The mock inspector instance starts with having an empty event blocks list.
        # Add pre-defined event blocks to test _get_runtime_intermediate_outputs().
        inspector_instance.event_blocks = [
            EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events())
        ]

        runtime_outputs = inspector_instance._get_runtime_intermediate_outputs()
        # Only debug handles 4 through 8 should survive, i.e. 5 keys
        self.assertEqual(len(runtime_outputs), 5)
        # Handles 0 and 2 belong to OPERATOR_CALL events and handles 1 and 3 to
        # events with empty op_types; none of them may appear in the result.
        for skipped_handle in range(4):
            self.assertNotIn((skipped_handle,), runtime_outputs)

        # Same debug_handle but different instruction_id, should record the last one
        self.assertIn((4,), runtime_outputs)
        self.assertTrue(
            torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
        )
        # Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
        for key in range(5, 9):
            self.assertIn((key,), runtime_outputs)
            self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
576+
538577
def _gen_random_float_list(self) -> List[float]:
    """Return RAW_DATA_SIZE floats, each drawn with random.uniform(0, 10)."""
    lower, upper = 0, 10
    return [random.uniform(lower, upper) for _ in range(RAW_DATA_SIZE)]
540579

580+
def _gen_random_runtime_output(
    self,
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
    """Return the elements of a torch.randn(RAW_DATA_SIZE) sample as a list."""
    sample = torch.randn(RAW_DATA_SIZE)
    return list(sample)
582+
541583
def _gen_random_events(self) -> List[Event]:
    """
    Build the list of events used by the Inspector tests.

    The list is laid out as follows (EVENTS_SIZE events in total):
      * two OPERATOR_CALL events (debug_handles/instruction_id 0 and 2),
        interleaved with two events whose op_types is empty
        (debug_handles/instruction_id 1 and 3);
      * op_2 and op_3, which share debug_handles 4 but have instruction_id 4
        and 5 and carry fixed tensors as debug_data;
      * op_4 onwards, each with its own debug handle and random debug_data.

    perf_data is always random; debug_data is random unless noted above.
    """
    events = []
    for i in range(2):
        events.append(
            # OPERATOR_CALL with debug_handles/instruction_id 0 and 2
            Event(
                name="OPERATOR_CALL",
                op_types=[OP_TYPE],
                perf_data=PerfData(self._gen_random_float_list()),
                debug_handles=i * 2,
                _instruction_id=i * 2,
                debug_data=self._gen_random_runtime_output(),
            )
        )
        events.append(
            # op_0/op_1 with empty op_types and with debug_handles/instruction_id 1 and 3
            Event(
                name=f"op_{i}",
                op_types=[],
                perf_data=PerfData(self._gen_random_float_list()),
                debug_handles=i * 2 + 1,
                _instruction_id=i * 2 + 1,
                debug_data=self._gen_random_runtime_output(),
            )
        )

    # op_2 with debug_handles/instruction_id 4
    events.append(
        Event(
            name="op_2",
            op_types=[OP_TYPE],
            perf_data=PerfData(self._gen_random_float_list()),
            debug_handles=4,
            debug_data=[torch.tensor([1.0, 2.0, 3.0])],
            _instruction_id=4,
        )
    )
    # op_3 also with debug_handles 4 but with instruction_id 5
    events.append(
        Event(
            name="op_3",
            op_types=[OP_TYPE],
            perf_data=PerfData(self._gen_random_float_list()),
            debug_handles=4,
            debug_data=[torch.tensor([4.0, 5.0, 6.0])],
            _instruction_id=5,
        )
    )

    # op_4 to op_7 with debug_handles 5 to 8 and instruction_id 6 to 9
    for i in range(4, EVENTS_SIZE - 2):
        events.append(
            Event(
                name=f"op_{i}",
                op_types=[OP_TYPE],
                perf_data=PerfData(self._gen_random_float_list()),
                debug_handles=i + 1,
                debug_data=self._gen_random_runtime_output(),
                _instruction_id=i + 2,
            )
        )
    return events

0 commit comments

Comments
 (0)