Commit 4c727ee

Merge branch 'main' into vela_400
2 parents 9653a65 + a89b858 commit 4c727ee

46 files changed (+2571 / −1788 lines)

.github/workflows/pull.yml

Lines changed: 7 additions & 0 deletions
```diff
@@ -971,6 +971,13 @@ jobs:
           ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
           ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row
 
+          # "Classic" Operator tests
+          PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build
+          # TODO(ssjia): figure out how to run custom op tests in CI. Currently, they are
+          # failing due to the libstdc++.so.6 installed with conda not supporting
+          # GLIBCXX_3.4.30. These tests are still run in Meta internal CI.
+          # ./cmake-out/backends/vulkan/test/op_tests/vulkan_sdpa_test
+
           # Run e2e testing for selected operators. More operators will be tested via this
           # route in the future.
           python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"
```
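The TODO above attributes the failure to a conda-provided `libstdc++.so.6` that lacks the `GLIBCXX_3.4.30` symbol version. As a hedged illustration (not part of this commit; the conda path and helper name are assumptions), a check along these lines can confirm that diagnosis on a runner:

```python
# Hypothetical diagnostic, not from this commit: report whether the active
# conda environment's libstdc++.so.6 exports the GLIBCXX_3.4.30 version tag.
import os
import subprocess


def has_glibcxx_3_4_30() -> bool:
    lib = os.path.join(
        os.environ.get("CONDA_PREFIX", "/opt/conda"), "lib", "libstdc++.so.6"
    )
    # `strings` (binutils) dumps printable strings, including symbol-version tags.
    out = subprocess.run(
        ["strings", lib], capture_output=True, text=True, check=True
    ).stdout
    return "GLIBCXX_3.4.30" in out


if __name__ == "__main__":
    print("GLIBCXX_3.4.30:", "present" if has_glibcxx_3_4_30() else "missing")
```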

backends/arm/README.md

Lines changed: 13 additions & 8 deletions
```diff
@@ -206,14 +206,6 @@ The current TOSA version does not support int64. However, int64 is commonly used
 - For quantized models, these transformations will be automatically handled during annotation before the export stage.
 
 List of model specific and optional passes:
-- InsertCastForOpsWithInt64InputPass
-  - Functionality:
-    - For LLMs such as Llama, some operators like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
-  - Supported Ops:
-    - aten.embedding.default, aten.slice_copy.Tensor
-  - Example usage:
-    - backends/arm/test/models/test_llama.py
-
 - ConvertInt64ConstOpsToInt32Pass
   - Functionalities:
     - Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
@@ -244,3 +236,16 @@ List of model specific and optional passes:
   - Example usage:
     - (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
     - (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+
+- InsertInt32CastsAfterInt64PlaceholdersPass
+  - Functionalities:
+    - Inserts an int64 -> int32 cast immediately after each int64 placeholder (graph input).
+    - Redirects all uses of each int64 placeholder to its int32 cast output.
+    - Inserts local int32 -> int64 casts at call sites where an operator requires int64 inputs, e.g. `torch.nn.functional.one_hot`.
+  - Pass ordering:
+    - When used with `ConvertInt64ConstOpsToInt32Pass` and `ConvertInt64OutputOpsToInt32Pass`, run this pass last.
+      - Rationale: those passes may cause retracing to re-infer some int64 placeholders as int32. Running this pass last casts only inputs that remain int64, minimizing inserted casts.
+  - Example usage:
+    - backends/arm/test/models/test_llama.py
+    - backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+    - backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
```
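To make the ordering guidance above concrete, here is a minimal sketch of a `transform_passes` list in the prescribed order, mirroring the test files updated later in this commit; everything around the list (tester, model, compile spec) is omitted:

```python
# Sketch only: the point is the pass order, with the placeholder-cast pass last.
from executorch.backends.arm._passes import (
    ConvertInt64ConstOpsToInt32Pass,
    ConvertInt64OutputOpsToInt32Pass,
    InsertInt32CastsAfterInt64PlaceholdersPass,
)

transform_passes = [
    ConvertInt64ConstOpsToInt32Pass(),
    ConvertInt64OutputOpsToInt32Pass(),
    # Run last: the two passes above may let retracing re-infer some int64
    # placeholders as int32, so this pass only casts inputs that remain int64.
    InsertInt32CastsAfterInt64PlaceholdersPass(),
]
```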

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -75,8 +75,8 @@
 from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass  # noqa
 from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass  # noqa
 from .fuse_quantized_activation_pass import FuseQuantizedActivationPass  # noqa
-from .insert_int64_input_cast_pass import (  # noqa # noqa
-    InsertCastForOpsWithInt64InputPass,
+from .insert_int32_casts_after_int64_placeholders import (  # noqa
+    InsertInt32CastsAfterInt64PlaceholdersPass,
 )
 from .insert_rescales_pass import InsertRescalePass  # noqa
 from .insert_table_ops import InsertTableOpsPass  # noqa
```

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -76,7 +76,7 @@
     FuseConstantArgsPass,
     FuseEqualPlaceholdersPass,
     FuseQuantizedActivationPass,
-    InsertCastForOpsWithInt64InputPass,
+    InsertInt32CastsAfterInt64PlaceholdersPass,
     InsertRescalePass,
     InsertTableOpsPass,
     MatchArgDtypePass,
@@ -277,7 +277,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         )  # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertion in Graph
         self.add_pass(ConvertInt64ConstOpsToInt32Pass())
         self.add_pass(ConvertInt64OutputOpsToInt32Pass())
-        self.add_pass(InsertCastForOpsWithInt64InputPass())
+        self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoundPass())
```
backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py

Lines changed: 122 additions & 0 deletions
This file was added. Its full content (the filename is inferred from the new import in backends/arm/_passes/__init__.py above):
```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import EdgeOpOverload, ExportPass, PassResult
from torch._subclasses.fake_tensor import FakeTensor


logger = logging.getLogger(__name__)


class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass):
    """
    Insert an int64->int32 cast after each int64 placeholder.

    Note: Overflow checks are not applied in this pass. It is the user's
    responsibility to ensure that values fit within the int32 range.
    """

    # Ops that require i64 inputs → positions of args to upcast.
    # Key: op overload; Value: zero-based indices of positional args that must be i64.
    I64_INPUT_ARG_POSITIONS = {
        torch.ops.aten.one_hot.default: (0,),
    }

    def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
        """
        If an operator requires int64 inputs but dtype propagation (via call_operator)
        produced int32, insert a local int32→int64 cast at the call site to satisfy
        PyTorch's operator input validation.
        """
        modified = False
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target not in self.I64_INPUT_ARG_POSITIONS:
                continue

            with graph.inserting_before(node):
                arg_positions = self.I64_INPUT_ARG_POSITIONS.get(node.target)
                args_list = list(node.args)
                for pos in arg_positions:  # type: ignore[union-attr]
                    input_arg = args_list[pos]
                    to_copy_op = self._get_decomposition(graph)
                    cast_node = graph_module.graph.create_node(
                        "call_function",
                        to_copy_op,
                        (input_arg,),
                        {"dtype": torch.int64},
                    )
                    cast_node.meta["val"] = node.meta["val"].to(torch.int64)
                    args_list[pos] = cast_node
                node.args = tuple(args_list)
                modified = True
        return modified

    def _graph_uses_edge_ops(self, graph: torch.fx.Graph) -> bool:
        for n in graph.nodes:
            if n.op == "call_function":
                if isinstance(n.target, EdgeOpOverload):
                    return True
        return False

    def _get_decomposition(self, graph: torch.fx.Graph):
        if self._graph_uses_edge_ops(graph):
            return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
        else:
            return torch.ops.dim_order_ops._to_dim_order_copy.default

    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype

    def _insert_placeholder_i64_to_i32_casts(self, graph_module: torch.fx.GraphModule):
        modified = False
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "placeholder":
                continue
            node_val = node.meta["val"]
            if not self._is_tensor_of_dtype(node_val, torch.int64):
                continue

            to_copy_op = self._get_decomposition(graph)
            with graph.inserting_after(node):
                cast_after = create_node(
                    graph,
                    to_copy_op,
                    args=(node,),
                    kwargs={
                        "dtype": torch.int32,
                    },
                )
                users = [user for user in node.users if user != cast_after]
                for user in users:
                    user.replace_input_with(node, cast_after)
                logger.warning(
                    f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 placeholder"
                    f" to int32 for {node.name} defined in {node.meta.get('stack_trace', '[no stack trace found]')}"
                )
                modified = True
        return modified

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        modified |= self._insert_placeholder_i64_to_i32_casts(graph_module)
        modified |= self._insert_callsite_i32_to_i64_casts(graph_module)

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)
```
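As a rough, untested sketch of the pass in action: exporting a toy module whose index input is int64 and running the pass should rewire that placeholder through an int64 -> int32 cast. `ToyEmbedding` and the shapes are invented for illustration, and an environment with executorch plus this commit is assumed:

```python
# Hedged usage sketch, not from this commit.
import torch

from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass


class ToyEmbedding(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(100, 8)

    def forward(self, ids):
        # Index tensors default to int64, so `ids` becomes an int64 placeholder.
        return self.emb(ids)


ids = torch.randint(0, 100, (1, 4), dtype=torch.int64)
graph_module = torch.export.export(ToyEmbedding(), (ids,)).graph_module

# The pass returns a PassResult; `modified` reports whether any cast was inserted.
result = InsertInt32CastsAfterInt64PlaceholdersPass()(graph_module)
print(result.modified)
result.graph_module.print_readable()
```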

backends/arm/_passes/insert_int64_input_cast_pass.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -11,7 +11,7 @@
 from executorch.backends.arm._passes import (
     ConvertInt64ConstOpsToInt32Pass,
     ConvertInt64OutputOpsToInt32Pass,
-    InsertCastForOpsWithInt64InputPass,
+    InsertInt32CastsAfterInt64PlaceholdersPass,
 )
 
 from executorch.backends.arm.test import common
@@ -33,10 +33,9 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
     # for that is some assert ops are removed by passes in the
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
-        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
-        "torch.ops.higher_order.executorch_call_delegate": 1,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
     }
 
     def _prepare_inputs(
@@ -71,9 +70,9 @@ def test_CLIPTextModelWithProjection_tosa_FP(self):
                 example_inputs=text_encoder_model_inputs,
                 compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
                 transform_passes=[
-                    InsertCastForOpsWithInt64InputPass(),
                     ConvertInt64ConstOpsToInt32Pass(),
                     ConvertInt64OutputOpsToInt32Pass(),
+                    InsertInt32CastsAfterInt64PlaceholdersPass(),
                 ],
             )
             .export()
```

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 11 additions & 7 deletions
```diff
@@ -22,18 +22,22 @@ class TestSD3Transformer2DModel(unittest.TestCase):
     SD3Transformer2DModel is the transformer model used by Stable Diffusion 3.5 Medium
     """
 
-    # Adjust nbr below as we increase op support. Note: most of the delegate
-    # calls are directly consecutive to each other in the .pte. The reason
-    # for that is some assert ops are removed by passes in the
-    # .to_executorch step, i.e. after Arm partitioner.
-    ops_after_partitioner = {
+    # Adjust nbr below as we increase op support.
+    ops_after_partitioner_FP = {
         "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
         "torch.ops.higher_order.executorch_call_delegate": 1,
     }
 
+    ops_after_partitioner_INT = {
+        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
+        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
+    }
+
     def _prepare_inputs(
         self,
         batch_size=2,
@@ -102,7 +106,7 @@ def test_SD3Transformer2DModel_tosa_FP(self):
             )
             .export()
             .to_edge_transform_and_lower()
-            .check_count(self.ops_after_partitioner)
+            .check_count(self.ops_after_partitioner_FP)
             .to_executorch()
             .run_method_and_compare_outputs(
                 inputs=sd35_transformer2D_model_inputs,
@@ -125,7 +129,7 @@ def test_SD3Transformer2DModel_tosa_INT(self):
             .quantize()
             .export()
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_count(self.ops_after_partitioner_INT)
             .to_executorch()
             .run_method_and_compare_outputs(
                 inputs=sd35_transformer2D_model_inputs,
```
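The FP and INT dictionaries above are what the pipeline's `check_count` verifies after partitioning. As a simplified, hypothetical illustration of where such numbers come from (the real harness matches normalized operator names such as `executorch_exir_dialects_edge__ops_aten_view_copy_default`, which this counter does not reproduce exactly), one can tally the `call_function` targets left outside the delegate:

```python
# Hypothetical helper, not from this commit: count call_function targets in a
# lowered graph; comparing such tallies is what check_count automates.
from collections import Counter

import torch


def count_call_functions(graph_module: torch.fx.GraphModule) -> Counter:
    return Counter(
        str(node.target)
        for node in graph_module.graph.nodes
        if node.op == "call_function"
    )
```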
