Skip to content

Commit c7a9671

Browse files
authored
Do not clear UncachedCompile between graph executions (#9282)
Co-authored-by: Haifeng Jin <[email protected]>
1 parent 01072eb commit c7a9671

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

test/dynamo/test_bridge.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,20 @@ def foo(t):
303303
t = torch.randint(0, 5, (3,), device=device)
304304
self._compile_and_check(foo, (t,))
305305

306+
def test_metrics_preserved(self):
307+
metrics.clear_counters()
308+
torch_xla._XLAC._xla_increment_counter('UncachedCompile', 3)
309+
torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', 4)
310+
311+
model = BasicModule()
312+
graph_module = torch.fx.symbolic_trace(model)
313+
collector = bridge.UnsupportedNodesCollector(graph_module)
314+
collector.run(torch.randn(1), torch.randn(1))
315+
316+
self.assertEqual(metrics.counter_value('UncachedCompile'), 3)
317+
self.assertEqual(metrics.counter_value('DynamoExtractCompiledGraph'), 4)
318+
metrics.clear_counters()
319+
306320

307321
if __name__ == "__main__":
308322
from torch._dynamo.test_case import run_tests

torch_xla/_dynamo/dynamo_bridge.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,7 @@ def run_node(self, n: torch.fx.Node):
646646
# We need to restore this metric count later, so save it in a separate variable
647647
dynamo_extract_graph_helper_metric_count = metrics.counter_value(
648648
'DynamoExtractCompiledGraph')
649+
uncached_compile_metric_count = metrics.counter_value('UncachedCompile')
649650

650651
metrics.clear_counters()
651652
result = super().run_node(n)
@@ -684,6 +685,9 @@ def all_tensors_on_xla_device(value):
684685
# Restore this metric counter
685686
torch_xla._XLAC._xla_increment_counter(
686687
'DynamoExtractCompiledGraph', dynamo_extract_graph_helper_metric_count)
688+
if uncached_compile_metric_count is not None:
689+
torch_xla._XLAC._xla_increment_counter('UncachedCompile',
690+
uncached_compile_metric_count)
687691

688692
return result
689693

0 commit comments

Comments
 (0)