Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import operator
from collections import deque
from numbers import Number
from typing import Any, Callable, cast
from typing import Any, Callable, cast, override

# Import these for the cadence function signatures.
import executorch.backends.cadence.aot.ops_registrations # noqa: F401
Expand All @@ -39,6 +39,9 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import Argument
from torch.fx.node import Argument

Check warning on line 43 in backends/cadence/aot/fuse_ops.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 F811

redefinition of unused 'Argument' from line 42 See https://www.flake8rules.com/rules/F811.html.
from torch.fx.passes.dialect.common.cse_pass import CSEPass
from torch.nn.utils.fusion import fuse_conv_bn_weights


Expand Down Expand Up @@ -1154,6 +1157,24 @@
return True


class HierarchicalCSEPass(ExportPass):
    """Recursively apply Common Subexpression Elimination (CSE) to a module tree.

    CSE runs on the top-level graph and, via the base class's submodule
    recursion, on every nested subgraph (e.g. the bodies of higher-order
    ops such as ``map_impl``), so redundant computations are removed at
    every level of the module hierarchy.
    """

    @override
    def call_submodule(
        self, graph_module: torch.fx.GraphModule, inputs: tuple[Argument, ...]
    ) -> PassResult:
        # Eliminate common subexpressions in this graph first, then let the
        # base class recurse into any nested submodules.
        cse_result = CSEPass().call(graph_module)
        recursed = super().call_submodule(cse_result.graph_module, inputs)
        # Always report the graph as modified.
        return PassResult(recursed.graph_module, True)


class CadenceFuseOpsInGraph:
passes = [
FuseMMWithAdd,
Expand Down
256 changes: 256 additions & 0 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -23,6 +23,7 @@
FuseMulTensorIntoQuantPass,
FuseQuantDequantToRequantizePass,
FuseTransposeOrPermuteOpPairsPass,
HierarchicalCSEPass,
)
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
Expand Down Expand Up @@ -1123,3 +1124,258 @@
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: num_forks,
},
)


class TestHierarchicalCSEPass(TestFusionPassesBase):
    """Unit tests for HierarchicalCSEPass.

    HierarchicalCSEPass eliminates redundant computations (common
    subexpressions) at every level of the module hierarchy, including the
    bodies of higher-order ops such as ``map_impl``.
    """

    # -------------------------------------------------------------------------
    # Graph-construction helpers
    # -------------------------------------------------------------------------

    def _create_duplicate_add_scalar_graph(
        self, shape: tuple[int, ...] = (8, 8)
    ) -> torch.fx.GraphModule:
        """Build a graph containing two identical add.Scalar nodes.

        Structure::

            x ──┬── add.Scalar(x, 1) ──┐
                └── add.Scalar(x, 1) ──┴── add.Tensor (result)
        """
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(*shape))
        lhs = gb.call_operator(exir_ops.edge.aten.add.Scalar, (inp, 1))
        rhs = gb.call_operator(exir_ops.edge.aten.add.Scalar, (inp, 1))
        out = gb.call_operator(exir_ops.edge.aten.add.Tensor, (lhs, rhs))
        gb.output([out])
        return gb.get_graph_module()

    def _create_different_add_scalar_graph(
        self, shape: tuple[int, ...] = (8, 8)
    ) -> torch.fx.GraphModule:
        """Build a graph whose add.Scalar nodes all use distinct scalars.

        Structure::

            x ──┬── add.Scalar(x, 1) ──┐
                ├── add.Scalar(x, 2) ──┼── add.Tensor chain (result)
                └── add.Scalar(x, 3) ──┘
        """
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(*shape))
        adds = [
            gb.call_operator(exir_ops.edge.aten.add.Scalar, (inp, value))
            for value in (1, 2, 3)
        ]
        partial = gb.call_operator(exir_ops.edge.aten.add.Tensor, (adds[0], adds[1]))
        out = gb.call_operator(exir_ops.edge.aten.add.Tensor, (partial, adds[2]))
        gb.output([out])
        return gb.get_graph_module()

    def _create_diamond_pattern_graph(
        self, shape: tuple[int, ...] = (32, 64)
    ) -> torch.fx.GraphModule:
        """Build a diamond graph mixing duplicate adds with distinct muls.

        Structure::

            x ──┬── add.Scalar(x, 5) ── mul.Scalar(_, 2) ──┐
                └── add.Scalar(x, 5) ── mul.Scalar(_, 3) ──┴── add.Tensor
        """
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(*shape))
        left = gb.call_operator(exir_ops.edge.aten.add.Scalar, (inp, 5))
        right = gb.call_operator(exir_ops.edge.aten.add.Scalar, (inp, 5))
        scaled_left = gb.call_operator(exir_ops.edge.aten.mul.Scalar, (left, 2))
        scaled_right = gb.call_operator(exir_ops.edge.aten.mul.Scalar, (right, 3))
        out = gb.call_operator(
            exir_ops.edge.aten.add.Tensor, (scaled_left, scaled_right)
        )
        gb.output([out])
        return gb.get_graph_module()

    def _create_map_body_with_duplicate_ops(
        self, sample_inp: torch.Tensor
    ) -> torch.fx.GraphModule:
        """Build a map body whose two add.Scalar nodes are identical."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", sample_inp)
        lhs = gb.call_operator(torch.ops.aten.add.Scalar, (inp, 1))
        rhs = gb.call_operator(torch.ops.aten.add.Scalar, (inp, 1))
        out = gb.call_operator(torch.ops.aten.add.Tensor, (lhs, rhs))
        gb.output([out])
        return gb.get_graph_module()

    def _create_map_body_with_mixed_ops(
        self, sample_inp: torch.Tensor
    ) -> torch.fx.GraphModule:
        """Build a map body with identical adds feeding distinct muls."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", sample_inp)
        lhs = gb.call_operator(torch.ops.aten.add.Scalar, (inp, 1))
        rhs = gb.call_operator(torch.ops.aten.add.Scalar, (inp, 1))
        scaled_lhs = gb.call_operator(torch.ops.aten.mul.Scalar, (lhs, 2))
        scaled_rhs = gb.call_operator(torch.ops.aten.mul.Scalar, (rhs, 3))
        out = gb.call_operator(torch.ops.aten.add.Tensor, (scaled_lhs, scaled_rhs))
        gb.output([out])
        return gb.get_graph_module()

    def _create_map_impl_graph(
        self,
        map_body: torch.fx.GraphModule,
        batch_size: int = 4,
        feature_size: int = 8,
    ) -> torch.fx.GraphModule:
        """Embed *map_body* in a graph that invokes it through map_impl."""
        gb = GraphBuilder()
        inp_proxy = gb.placeholder("inp", torch.randn(batch_size, feature_size))
        mapped = gb.call_operator(
            torch.ops.higher_order.map_impl, (map_body, (inp_proxy,), ())
        )
        out = gb.call_getitem(mapped, 0)
        gb.output([out])
        return gb.get_graph_module()

    def _get_map_body(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Return the body submodule of the single map_impl node in *gm*."""
        map_nodes = gm.graph.find_nodes(
            op="call_function", target=torch.ops.higher_order.map_impl
        )
        self.assertEqual(len(map_nodes), 1, "Should have exactly one map_impl node")
        # The first arg of map_impl is a get_attr node referencing the body.
        body_ref = map_nodes[0].args[0]
        self.assertTrue(hasattr(gm, body_ref.target))
        body = getattr(gm, body_ref.target)
        self.assertIsInstance(body, torch.fx.GraphModule)
        return cast(torch.fx.GraphModule, body)

    def _apply_cse_pass(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Run HierarchicalCSEPass on *gm* and return the transformed module."""
        return cast(PassResult, HierarchicalCSEPass()(gm)).graph_module

    # -------------------------------------------------------------------------
    # Test cases
    # -------------------------------------------------------------------------

    def test_cse_removes_duplicate_add_scalar(self) -> None:
        """Duplicate add.Scalar nodes with identical args are merged."""
        original = self._create_duplicate_add_scalar_graph()
        self.assertEqual(
            count_node(original, exir_ops.edge.aten.add.Scalar),
            2,
            "Should have 2 duplicate add.Scalar before CSE",
        )

        transformed = self._apply_cse_pass(original)

        self.assertEqual(
            count_node(transformed, exir_ops.edge.aten.add.Scalar),
            1,
            "CSE should have eliminated duplicate add.Scalar operation",
        )

    def test_cse_with_map_impl_duplicate_ops(self) -> None:
        """Duplicates inside a map_impl body are merged."""
        sample_inp = torch.randn(8)
        body = self._create_map_body_with_duplicate_ops(sample_inp)
        original = self._create_map_impl_graph(body)

        # Sanity-check the graph before running the pass.
        self.assertEqual(
            count_node(self._get_map_body(original), torch.ops.aten.add.Scalar),
            2,
            "Map body should have 2 duplicate add.Scalar ops before CSE",
        )

        transformed = self._apply_cse_pass(original)

        self.assertEqual(
            count_node(self._get_map_body(transformed), torch.ops.aten.add.Scalar),
            1,
            "CSE should have eliminated duplicate add.Scalar in map body",
        )

    def test_cse_with_map_impl_mixed_duplicate_and_unique_ops(self) -> None:
        """Only the duplicate ops inside a map_impl body are merged."""
        sample_inp = torch.randn(8)
        body = self._create_map_body_with_mixed_ops(sample_inp)
        original = self._create_map_impl_graph(body)

        # Sanity-check the graph before running the pass.
        body_before = self._get_map_body(original)
        self.assertEqual(
            count_node(body_before, torch.ops.aten.add.Scalar),
            2,
            "Should have 2 duplicate add.Scalar before CSE",
        )
        self.assertEqual(
            count_node(body_before, torch.ops.aten.mul.Scalar),
            2,
            "Should have 2 different mul.Scalar before CSE",
        )

        transformed = self._apply_cse_pass(original)

        body_after = self._get_map_body(transformed)
        self.assertEqual(
            count_node(body_after, torch.ops.aten.add.Scalar),
            1,
            "CSE should have merged duplicate add.Scalar to 1",
        )
        self.assertEqual(
            count_node(body_after, torch.ops.aten.mul.Scalar),
            2,
            "CSE should NOT merge different mul.Scalar operations",
        )

    def test_cse_preserves_different_operations(self) -> None:
        """Ops with different arguments must survive the pass untouched."""
        original = self._create_different_add_scalar_graph()
        self.assertEqual(
            count_node(original, exir_ops.edge.aten.add.Scalar),
            3,
            "Should have 3 different add.Scalar before CSE",
        )

        transformed = self._apply_cse_pass(original)

        self.assertEqual(
            count_node(transformed, exir_ops.edge.aten.add.Scalar),
            3,
            "CSE should NOT eliminate add.Scalar ops with different scalar values",
        )

    def test_cse_diamond_pattern(self) -> None:
        """Diamond-shaped graphs merge only the duplicate branch heads."""
        original = self._create_diamond_pattern_graph()
        self.check_op_counts(
            original,
            expected_op_counts={
                exir_ops.edge.aten.add.Scalar: 2,
                exir_ops.edge.aten.mul.Scalar: 2,
            },
        )

        transformed = self._apply_cse_pass(original)

        self.check_op_counts(
            transformed,
            expected_op_counts={
                exir_ops.edge.aten.add.Scalar: 1,  # duplicates merged
                exir_ops.edge.aten.mul.Scalar: 2,  # distinct args, kept apart
            },
        )
Loading