Skip to content

Commit 3124a6b

Browse files
authored
Fix XNNPACK handling of negative permute dims (#14169)
### Summary Fix handling of negative dims in permute ops on XNNPACK. They currently fail to lower. This bug was surfaced by the backend operator test suite. ### Test plan I've added a test to cover negative dims in permute.
1 parent 8a0a25b commit 3124a6b

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

backends/xnnpack/operators/op_permute.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,21 @@ def define_node(
4444
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
4545

4646
# input
47-
input_id = vals_to_ids[get_input_node(node, 0)]
47+
input_node = get_input_node(node, 0)
48+
input_id = vals_to_ids[input_node]
4849

4950
# output
5051
output_id = vals_to_ids[node]
5152

5253
# permutation
54+
input_rank = input_node.meta["val"].dim()
5355
permute_order = cast(List[int], node.args[1])
5456

57+
# Handle negative dimensions by converting them to positive indices
58+
permute_order = [
59+
(dim + input_rank) if dim < 0 else dim for dim in permute_order
60+
]
61+
5562
# change permute order if under channels last
5663
is_channels_last = node.meta.get(
5764
ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False

backends/xnnpack/test/ops/test_permute.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ def test_fp32_permute(self):
5555
inputs = (torch.randn(1, 1, 4, 4),)
5656
self._test_permute(inputs)
5757

58+
def test_fp32_permute_negative_dim(self):
59+
inputs = (torch.randn(1, 1, 4, 4),)
60+
(
61+
Tester(self.Permute([0, -2, -1, 1]), inputs)
62+
.export()
63+
.check_count({"torch.ops.aten.permute.default": 1})
64+
.to_edge_transform_and_lower()
65+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
66+
.check_not(["executorch_exir_dialects_edge__ops_aten_permute_copy_default"])
67+
.to_executorch()
68+
.serialize()
69+
.run_method_and_compare_outputs()
70+
)
71+
5872
def test_fp32_permute_copy(self):
5973
inputs = (torch.randn(1, 1, 4, 4),)
6074
(

0 commit comments

Comments
 (0)