Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions backends/nxp/aten_passes/convert_unsqueeze_to_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class ConvertUnsqueezeToViewPass(PassBase):
    """Replace every 'aten.unsqueeze.default' node with 'aten.view.default'.

                  x                                             x
                  │                                             │
    ┌─────────────▼─────────────┐                 ┌─────────────▼─────────────┐
    │  aten.unsqueeze(x, dim)   │ ──────────────► │  aten.view.default(x, S)  │
    └─────────────┬─────────────┘  replace with   └─────────────┬─────────────┘
                  │                                             │
                  ▼                                             ▼
                 out                                           out

    `S` is the output shape stored in the `meta["val"]` of the replaced
    unsqueeze node.
    """

    @staticmethod
    def _is_unsqueeze(node_: Node) -> bool:
        """Return True if `node_` is a call to `aten.unsqueeze.default`."""
        return (
            node_.op == "call_function"
            and node_.target == torch.ops.aten.unsqueeze.default
        )

    def _create_view_node(self, *view_args) -> Node:
        """Create an `aten.view.default` node at the graph's current insert point.

        :param view_args: Positional arguments of the view operator. The first
                           one must be a `Node` carrying `meta["val"]`; the
                           rest are forwarded to the operator as they are.
        :return: The new node with `source_fn_stack` and `val` metadata set.
        """
        view_target = torch.ops.aten.view.default
        view_node = self.graph_module.graph.call_function(view_target, view_args)

        view_node.meta["source_fn_stack"] = [(view_node.name, view_target)]

        # Derive the output metadata by running the operator on fake tensors,
        # so only the shape and dtype of the input are needed.
        x_val = view_args[0].meta["val"]
        with FakeTensorMode() as mode:
            fake_input = FakeTensor.from_tensor(
                torch.empty(x_val.shape, dtype=x_val.dtype), mode
            )
            output_shape = view_target(fake_input, *view_args[1:]).shape
            view_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(output_shape, dtype=x_val.dtype), mode
            )

        return view_node

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        """Rewrite all unsqueeze nodes of `graph_module` in place.

        :param graph_module: The module whose graph is transformed.
        :return: `PassResult` holding the same module and a flag telling
                 whether at least one node was replaced.
        """
        self.graph_module = graph_module
        graph = graph_module.graph
        made_changes = False

        # Iterate over a snapshot, because the graph is mutated in the loop.
        for node in list(graph.nodes):
            if not self._is_unsqueeze(node):
                continue

            input_node = node.all_input_nodes[0]
            # Store the size as a plain list instead of a `torch.Size` object.
            target_size = list(node.meta["val"].shape)

            with graph.inserting_after(node):
                view_node = self._create_view_node(input_node, target_size)

            node.replace_all_uses_with(view_node)
            graph.erase_node(node)

            made_changes = True

        if made_changes:
            # Clean the graph first, so that the regenerated Python code
            # matches the final graph.
            graph.eliminate_dead_code()
            graph_module.recompile()

        return PassResult(graph_module, made_changes)
6 changes: 5 additions & 1 deletion backends/nxp/aten_passes/neutron_aten_pass_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 NXP
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -7,6 +7,9 @@

import torch

from executorch.backends.nxp.aten_passes.convert_unsqueeze_to_view import (
ConvertUnsqueezeToViewPass,
)
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import (
FuseBatchNormWithConvPass,
)
Expand Down Expand Up @@ -49,6 +52,7 @@ def __init__(
RemoveNodesWithKnownOutputs(),
FuseLinearAndAddPass(),
MoveActivationBeforeConcat(neutron_target_spec),
ConvertUnsqueezeToViewPass(),
]

super().__init__(passes)
Expand Down
11 changes: 10 additions & 1 deletion backends/nxp/tests/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024-2025 NXP
# Copyright (c) 2024-2026 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -670,3 +670,12 @@ def __init__(self):

def forward(self, x):
return self.sequential(x)


class UnsqueezeAddModel(torch.nn.Module):
    """Add two tensors and unsqueeze the sum at a fixed dimension."""

    def __init__(self, dim):
        """Store `dim`, the position at which `forward` inserts the new axis."""
        super().__init__()
        self.dim = dim

    def forward(self, x, y):
        """Return `x + y` with a size-1 dimension inserted at `self.dim`."""
        summed = x + y
        return torch.unsqueeze(summed, self.dim)
143 changes: 143 additions & 0 deletions backends/nxp/tests/test_convert_unsqueeze_to_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import pytest
import torch
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
ConvertUnsqueezeToViewPass,
NeutronAtenPassManager,
)
from executorch.backends.nxp.backend.edge_program_converter import (
EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import (
neutron_target_spec,
to_quantized_edge_program,
)
from executorch.backends.nxp.tests.executors import (
convert_run_compare,
graph_contains_any_of_ops,
)

from executorch.backends.nxp.tests.models import UnsqueezeAddModel
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


@pytest.fixture(autouse=True)
def reseed_model_per_test_run():
    """Seed the NumPy and PyTorch random generators with fixed values.

    The fixture is `autouse`, so it runs before every test in this module.
    """
    np.random.seed(23)
    torch.manual_seed(42)


@pytest.mark.parametrize(
    "input_shape, dim",
    [
        pytest.param((2,), 0, id="1D."),
        pytest.param((8, 4, 6), 2, id="3D."),
        pytest.param((8, 4, 6, 8), -2, id="4D, negative dim."),
        # The ids must be unique, so that each case can be selected by name.
        pytest.param((8, 4, 6), 3, id="3D, dim arg is clipped (positive)."),
        pytest.param((8, 4, 6), -4, id="3D, dim arg is clipped (negative)."),
    ],
)
def test_convert_unsqueeze_to_view_simple(mocker, input_shape, dim):
    """Run only `ConvertUnsqueezeToViewPass` on an exported ATen module.

    Checks that the unsqueeze node is replaced by a view node and that the
    module output stays the same.
    """
    model = UnsqueezeAddModel(dim)

    example_input = (torch.rand(input_shape), torch.rand(input_shape))

    exir_program_aten = torch.export.export(model, example_input).module()

    # Check "aten.unsqueeze.default" is present
    assert graph_contains_any_of_ops(
        exir_program_aten.graph, [torch.ops.aten.unsqueeze.default]
    )

    outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)]

    # Apply the optimization.
    pass_manager = NeutronAtenPassManager(
        neutron_target_spec, [ConvertUnsqueezeToViewPass()]
    )
    pass_manager(exir_program_aten)

    # Make sure no "aten.unsqueeze.default" is in the model.
    assert not graph_contains_any_of_ops(
        exir_program_aten.graph,
        [torch.ops.aten.unsqueeze.default],
    )

    # Make sure there is "aten.view.default" in the model.
    assert graph_contains_any_of_ops(
        exir_program_aten.graph,
        [torch.ops.aten.view.default],
    )

    outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)]

    # Make sure the model still produces the same output.
    assert len(outputs_before) == len(outputs_after)

    for before, after in zip(outputs_before, outputs_after):
        assert np.allclose(before, after)


@pytest.mark.parametrize(
    "input_shape, dim",
    [
        pytest.param((2,), 0, id="1D."),
        pytest.param((8, 4, 6), 2, id="3D."),
        pytest.param((8, 4, 6, 8), -2, id="4D, negative dim."),
        # The ids must be unique, so that each case can be selected by name.
        pytest.param((8, 4, 6), 3, id="3D, dim arg is clipped (positive)."),
        pytest.param((8, 4, 6), -4, id="3D, dim arg is clipped (negative)."),
    ],
)
def test_convert_unsqueeze_to_view_full_pipeline(mocker, input_shape, dim):
    """Run the quantized edge pipeline on `UnsqueezeAddModel`.

    Checks that no unsqueeze node remains, that the program passed to the
    converter contains a view_copy node, and compares the converted model
    with the exported program via `convert_run_compare`.
    """
    model = UnsqueezeAddModel(dim)
    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    # Run conversion
    edge_program = to_quantized_edge_program(
        model,
        [input_shape, input_shape],
    ).exported_program()

    # Make sure no "aten.unsqueeze.default" is in the model.
    assert not graph_contains_any_of_ops(
        edge_program.graph,
        [torch.ops.aten.unsqueeze.default],
    )

    # Capture generated model
    neutron_ir_model = converter_spy.spy_return[0]
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    # Make sure "edge.aten.view_copy.default" is in the model.
    assert graph_contains_any_of_ops(
        exported_program.graph,
        [exir_ops.edge.aten.view_copy.default],
    )

    # One random int8 tensor per model input, keyed by the input index.
    example_input = {
        idx: (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
        for idx in range(2)
    }

    convert_run_compare(
        exported_program,
        input_data=example_input,
        tfl_model=neutron_ir_model,
    )
Loading