Skip to content

Commit 222d9e3

Browse files
authored
Add inspector numeric gap calculation between AOT and runtime intermediate outputs
Differential Revision: D76831086 Pull Request resolved: #11855
1 parent d83636d commit 222d9e3

File tree

5 files changed

+142
-9
lines changed

5 files changed

+142
-9
lines changed

devtools/inspector/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ python_library(
1919
"//executorch/devtools/etrecord:etrecord",
2020
"//executorch/exir:lib",
2121
"//executorch/devtools/inspector:intermediate_output_capturer",
22+
"//executorch/devtools/inspector/numerical_comparator:lib",
2223
],
2324
)
2425

devtools/inspector/_inspector.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
inflate_runtime_output,
5656
is_debug_output,
5757
is_inference_output_equal,
58+
map_runtime_aot_intermediate_outputs,
5859
ProgramOutput,
5960
RESERVED_FRAMEWORK_EVENT_NAMES,
6061
TimeScale,
@@ -63,6 +64,10 @@
6364
from executorch.devtools.inspector._intermediate_output_capturer import (
6465
IntermediateOutputCapturer,
6566
)
67+
from executorch.devtools.inspector.numerical_comparator import (
68+
L1Comparator,
69+
MSEComparator,
70+
)
6671
from executorch.exir import ExportedProgram
6772

6873

@@ -1337,3 +1342,50 @@ def get_exported_program(
13371342
if graph is None
13381343
else self._etrecord.graph_map.get(graph)
13391344
)
1345+
1346+
def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
1347+
"""
1348+
Compares logged intermediate outputs from the exported graph (in ETRecord)
1349+
with runtime outputs (in ETDump) using a user-specific numerical comparator.
1350+
1351+
Args:
1352+
distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".
1353+
1354+
Returns:
1355+
pd.DataFrame: A DataFrame listing corresponding operator outputs from
1356+
both stages and their computed numerical gaps.
1357+
"""
1358+
if self._aot_intermediate_outputs is None:
1359+
raise ValueError(
1360+
"The aot intermediate outputs is required but not populated."
1361+
)
1362+
mapping = map_runtime_aot_intermediate_outputs(
1363+
self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs()
1364+
)
1365+
metric = distance.strip().upper()
1366+
if metric == "MSE":
1367+
comparator = MSEComparator()
1368+
elif metric == "L1":
1369+
comparator = L1Comparator()
1370+
else:
1371+
raise ValueError(f"Unsupported distance metric {distance!r}")
1372+
1373+
rows = []
1374+
for (aot_debug_handle, aot_intermediate_output), (
1375+
runtime_debug_handle,
1376+
runtime_intermediate_output,
1377+
) in mapping.items():
1378+
if aot_intermediate_output is None or runtime_intermediate_output is None:
1379+
continue
1380+
rows.append(
1381+
{
1382+
"aot_debug_handle": aot_debug_handle,
1383+
"aot_intermediate_output": aot_intermediate_output,
1384+
"runtime_debug_handle": runtime_debug_handle,
1385+
"runtime_intermediate_output": runtime_intermediate_output,
1386+
"gap": comparator.compare(
1387+
aot_intermediate_output, runtime_intermediate_output
1388+
),
1389+
}
1390+
)
1391+
return pd.DataFrame(rows)

devtools/inspector/_inspector_utils.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import math
1010
import sys
11+
from collections.abc import Sequence
1112
from dataclasses import dataclass
1213
from enum import Enum
1314
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
@@ -676,17 +677,25 @@ def map_runtime_aot_intermediate_outputs(
676677
# Map only if both AOT and runtime data are present.
677678
if len(aot_list) != 0 and len(runtime_list) != 0:
678679
# Combine aot debug handles into a single key
679-
aot_combined_debug_handle, aot_output = (
680+
aot_combined_debug_handle, aot_intermediate_output = (
680681
_combine_overlapped_intermediate_outputs(aot_list)
681682
)
682683
# Combine runtime debug handles into a single key
683-
runtime_combined_debug_handle, runtime_output = (
684+
runtime_combined_debug_handle, runtime_intermediate_output = (
684685
_combine_overlapped_intermediate_outputs(runtime_list)
685686
)
687+
# List can't be used as a key, so convert to tuple
688+
if isinstance(aot_intermediate_output, list):
689+
aot_intermediate_output = tuple(aot_intermediate_output)
690+
# runtime follow the same format as aot, so it's safe to convert to tuple
691+
if isinstance(runtime_intermediate_output, list):
692+
runtime_intermediate_output = tuple(runtime_intermediate_output)
686693
# Create a mapping between runtime and aot
687-
aot_runtime_mapping[(aot_combined_debug_handle, aot_output)] = (
694+
aot_runtime_mapping[
695+
(aot_combined_debug_handle, aot_intermediate_output)
696+
] = (
688697
runtime_combined_debug_handle,
689-
runtime_output,
698+
runtime_intermediate_output,
690699
)
691700

692701
return aot_runtime_mapping
@@ -698,7 +707,7 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
698707
This function handles the following types of input:
699708
- Scalar (int or float): Converts to a tensor with a single element.
700709
- Tensor: Converts to a float64 tensor on CPU.
701-
- List of Tensors: Stacks the tensors into a single float64 tensor on CPU.
710+
- Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
702711
The resulting tensor is detached, moved to CPU, and cast to torch.float64.
703712
Parameters:
704713
input_data (Any): The input data to be converted to a tensor. It can be a scalar,
@@ -709,8 +718,8 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
709718
ValueError: If the input_data cannot be converted to a tensor.
710719
"""
711720
try:
712-
# Check if the input is a list of tensors
713-
if isinstance(input_data, list):
721+
# Check if the input is a Sequence of tensors
722+
if isinstance(input_data, Sequence):
714723
input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data])
715724
# Try to convert the input to a tensor
716725
else:

devtools/inspector/numerical_comparator/TARGETS

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ python_library(
1414
srcs = ["l1_numerical_comparator.py"],
1515
deps = [
1616
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
17-
"//executorch/devtools/inspector:lib",
17+
"//executorch/devtools/inspector:inspector_utils",
1818
],
1919
)
2020

@@ -23,7 +23,7 @@ python_library(
2323
srcs = ["mse_numerical_comparator.py"],
2424
deps = [
2525
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
26-
"//executorch/devtools/inspector:lib",
26+
"//executorch/devtools/inspector:inspector_utils",
2727
],
2828
)
2929

devtools/inspector/tests/inspector_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from unittest.mock import patch
1919

20+
import pandas as pd
21+
2022
import torch
2123
import torch.fx
2224

@@ -578,6 +580,75 @@ def test_get_runtime_intermediate_outputs(self):
578580
self.assertIn((key,), runtime_outputs)
579581
self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
580582

583+
def test_calculate_numeric_gap(self):
584+
# Create a context manager to patch functions called by Inspector.__init__
585+
with patch.object(
586+
_inspector, "parse_etrecord", return_value=None
587+
), patch.object(
588+
_inspector, "gen_etdump_object", return_value=None
589+
), patch.object(
590+
EventBlock, "_gen_from_etdump"
591+
), patch.object(
592+
_inspector, "gen_graphs_from_etrecord"
593+
):
594+
# Call the constructor of Inspector
595+
inspector_instance = Inspector(
596+
etdump_path=ETDUMP_PATH,
597+
etrecord=ETRECORD_PATH,
598+
)
599+
600+
aot_intermediate_outputs = {
601+
(0,): torch.tensor([1.0, 2.0, 3.0]),
602+
(1,): torch.tensor([4.0, 5.0, 6.0]),
603+
}
604+
605+
runtime_intermediate_outputs = {
606+
(0,): torch.tensor([2.0, 1.0, 4.0]),
607+
(1,): torch.tensor([3.0, 6.0, 5.0]),
608+
}
609+
610+
inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs
611+
inspector_instance._get_runtime_intermediate_outputs = (
612+
lambda: runtime_intermediate_outputs
613+
)
614+
615+
df = inspector_instance.calculate_numeric_gap(distance="L1")
616+
self.assertIsInstance(df, pd.DataFrame)
617+
self.assertEqual(len(df), 2)
618+
cols = set(df.columns)
619+
expected_cols = {
620+
"aot_debug_handle",
621+
"aot_intermediate_output",
622+
"runtime_debug_handle",
623+
"runtime_intermediate_output",
624+
"gap",
625+
}
626+
self.assertEqual(cols, expected_cols)
627+
founded_aot_debug_handle = set(df["aot_debug_handle"])
628+
self.assertEqual(
629+
founded_aot_debug_handle, set(aot_intermediate_outputs.keys())
630+
)
631+
for _, row in df.iterrows():
632+
aot_debuh_handle = row["aot_debug_handle"]
633+
# aot_intermediate_output should equal aot_intermediate_outputs[h]
634+
self.assertTrue(
635+
torch.allclose(
636+
row["aot_intermediate_output"],
637+
aot_intermediate_outputs[aot_debuh_handle],
638+
)
639+
)
640+
# runtime_debug_hanlde equals aot_debug_handle at this case
641+
self.assertEqual(row["runtime_debug_handle"], aot_debuh_handle)
642+
# runtime_intermediate_output should equal runtime_intermediate_outputs[h]
643+
self.assertTrue(
644+
torch.allclose(
645+
row["runtime_intermediate_output"],
646+
runtime_intermediate_outputs[aot_debuh_handle],
647+
)
648+
)
649+
# gap should equal 3.0
650+
self.assertEqual(row["gap"], 3.0)
651+
581652
def _gen_random_float_list(self) -> List[float]:
582653
return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]
583654

0 commit comments

Comments
 (0)