
Commit ad5a00a

GregoryComer authored and facebook-github-bot committed
Remove no-op clones (WIP) (pytorch#15838)
Summary: Pull Request resolved: pytorch#15838

Differential Revision: D86588171
1 parent f4a9268 commit ad5a00a
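
For context, the "no-op" clones targeted here are aten.clone calls that change neither values nor memory layout, so the output is indistinguishable from the input. A minimal illustration in plain PyTorch (not code from this commit):

import torch

x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)

# clone() with the default memory_format (preserve_format) keeps the
# input's strides, so the copy is functionally redundant in a graph.
y = x.clone()
assert y.stride() == x.stride() and torch.equal(y, x)

# A clone that changes layout is NOT a no-op and must be preserved.
z = x.clone(memory_format=torch.contiguous_format)
assert z.stride() != x.stride()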

File tree: 3 files changed (+59, -55 lines)


backends/transforms/test/test_remove_clone_ops.py

Lines changed: 3 additions & 28 deletions
@@ -148,12 +148,15 @@ def test_clone_non_identity_survives(self):
             x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)

             exported = export(model.eval(), (x,), strict=True)
+            print(exported)
             before_epm = to_edge(
                 exported,
                 compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
             )
+            print(before_epm.exported_program())

             updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+            print(updated_epm.exported_program())

             FileCheck().check_count(clone_op_str, 1, exactly=True).run(
                 updated_epm.exported_program().graph_module.code
@@ -164,34 +167,6 @@ def test_clone_non_identity_survives(self):
             assert torch.allclose(actual, expected)
             assert is_channel_last_dim_order(actual)

-    def test_clone_identity_removed(self):
-        """Verify identity clone ops are removed by RemoveCloneOpsTransform."""
-
-        for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
-            model = SimpleCloneChannelsLastModule()
-            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
-
-            exported = export(model.eval(), (x,), strict=True)
-            before_epm = to_edge(
-                exported,
-                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
-            )
-
-            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
-                before_epm.exported_program().graph_module.code
-            )
-
-            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
-
-            FileCheck().check_not(clone_op_str).run(
-                updated_epm.exported_program().graph_module.code
-            )
-
-            expected = before_epm.exported_program().module()(x)
-            actual = updated_epm.exported_program().module()(x)
-            assert torch.allclose(actual, expected)
-            assert is_channel_last_dim_order(actual)
-

 if __name__ == "__main__":
     unittest.main()
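
Aside: FileCheck, used throughout these tests, asserts over the generated graph-module code as plain text. A tiny hedged usage sketch (the checked string is illustrative, not taken from this commit):

from torch.testing import FileCheck

code = "x = torch.ops.aten.clone.default(arg0)"
FileCheck().check_count("torch.ops.aten.clone.default", 1, exactly=True).run(code)
FileCheck().check_not("torch.ops.aten.slice_copy").run(code)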

exir/passes/remove_noop_pass.py

Lines changed: 25 additions & 27 deletions
@@ -47,7 +47,7 @@ class RemoveNoopPass(ExportPass):
     """
     Removes noops that pass through arguments.
     """
-
+
     def call(self, graph_module: GraphModule) -> PassResult:

         # In this list we'll collect all the dequant nodes that are inputs to ops that
@@ -56,35 +56,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
         dequant_nodes = []

         for node in graph_module.graph.nodes:
-            if node.op != "call_function":
-                continue
-
-            if node.target not in (
-                torch.ops.aten.to.dtype,
-                torch.ops.aten.dropout.default,
-                torch.ops.aten.slice_copy.Tensor,
-            ):
-                continue
-
-            orig_tensor = node.args[0].meta["val"]
-
-            if orig_tensor is node.meta["val"]:
-                # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
-                # Otherwise, removing only the op will suffice.
+            if RemoveNoopPass._should_remove_node(node):
                 if node.args[0].target in _DEQUANT_OPS:
                     dequant_nodes += [node.args[0]]
                 node.replace_all_uses_with(node.args[0])
-                continue
-
-            if node.target == torch.ops.aten.slice_copy.Tensor:
-                # Only do this check if all the dims are static.
-                if all(isinstance(dim, int) for dim in orig_tensor.size()):
-                    if orig_tensor.shape == node.meta["val"].shape:
-                        # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
-                        # Otherwise, removing only the op will suffice.
-                        if node.args[0].target in _DEQUANT_OPS:
-                            dequant_nodes += [node.args[0]]
-                        node.replace_all_uses_with(node.args[0])

         graph_module.graph.eliminate_dead_code()
         eliminate_dq_q(graph_module, dequant_nodes)
@@ -93,6 +68,29 @@ def call(self, graph_module: GraphModule) -> PassResult:

         return PassResult(graph_module, True)

+    @staticmethod
+    def _should_remove_node(node: torch.fx.Node) -> bool:
+        if node.op != "call_function":
+            return False
+
+        input_meta_val = node.args[0].meta.get("val", None) if len(node.args) > 0 and hasattr(node.args[0], "meta") else None
+
+        if input_meta_val is not None:
+            if node.target in (
+                torch.ops.aten.to.dtype,
+                torch.ops.aten.dropout.default,
+            ):
+                return input_meta_val is node.meta["val"]
+            elif node.target == torch.ops.aten.slice_copy.Tensor:
+                # Only do this check if all the dims are static.
+                return all(isinstance(dim, int) for dim in input_meta_val.size()) and input_meta_val.shape == node.meta["val"].shape
+            elif node.target == torch.ops.aten.clone.default:
+                # Remove if memory_format=None, preserve_format, or input already has the target memory format.
+                dest_memory_format = node.kwargs.get("memory_format", None) or torch.preserve_format
+                return dest_memory_format == torch.preserve_format or input_meta_val.is_contiguous(memory_format=dest_memory_format)
+
+        return False
+

 class RemoveToCopyPass(ExportPass):
     """

exir/tests/test_passes.py

Lines changed: 31 additions & 0 deletions
@@ -2093,3 +2093,34 @@ def forward(self, x):
             prop_tensor.is_contiguous(),
             f"Propagated tensor is not contiguous: {prop_tensor.stride()}",
         )
+
+    def test_remove_noop_pass_clone(self) -> None:
+        """
+        Verify the no-op clones are removed from the graph.
+        """
+
+        class CloneModel(torch.nn.Module):
+            def forward(self, x):
+                return x.clone() + x.clone()
+
+        model = CloneModel()
+        inputs = (torch.randn(1, 16),)
+
+        ep = torch.export.export(model, inputs)
+        lowered = to_edge_transform_and_lower(ep)
+
+        # Sanity check the test - we should see clones in the exported program
+        self.assertTrue(
+            any(
+                n.op == "call_function" and n.target == torch.ops.aten.clone.default
+                for n in ep.graph.nodes
+            )
+        )
+
+        # Since the clone ops are no-ops, they should be gone.
+        self.assertFalse(
+            any(
+                n.op == "call_function" and n.target == exir_ops.edge.dim_order_ops._clone_dim_order.default
+                for n in lowered.exported_program().graph.nodes
+            )
+        )
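
Note the asymmetry in the two assertions: before lowering the test scans for torch.ops.aten.clone.default, while the post-lowering check scans for the edge dim-order variant, exir_ops.edge.dim_order_ops._clone_dim_order.default. A hedged helper for such graph scans (count_ops is my name, not from this PR; only the presence check is asserted by the test itself):

import torch

def count_ops(graph: torch.fx.Graph, target) -> int:
    # Count call_function nodes with the given target, mirroring the
    # any(...) scans in the test above.
    return sum(
        1 for n in graph.nodes if n.op == "call_function" and n.target == target
    )

class CloneModel(torch.nn.Module):
    def forward(self, x):
        return x.clone() + x.clone()

ep = torch.export.export(CloneModel(), (torch.randn(1, 16),))
print(count_ops(ep.graph, torch.ops.aten.clone.default))  # expect >= 1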
