
Perhaps a bug with op aten.split_with_sizes? #7430


Description

@Hosh1ro

🐛 Describe the bug

I wrote a custom model like this:

import torch
from torch import nn, Tensor

class TorchTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1)

    def forward(self, x, pos_embed) -> Tensor:
        # The positional embedding is added to query/key only, not to value.
        q = k = x + pos_embed
        # nn.MultiheadAttention returns (attn_output, attn_weights);
        # keep only the output so the return annotation holds.
        x, _ = self.attn(q, k, value=x)
        return x

Then I tried exporting it to the QNN backend using the following code:

# `args` comes from the example script's argument parser
# (SoC model, artifact directory), as in examples/qualcomm/scripts.
example_inputs = (torch.ones((1, 1, 512)), torch.ones((1, 1, 512)))

pte_filename = "test"
instance = TorchTest()

build_executorch_binary(
    instance.eval(),
    example_inputs,
    args.model,
    f"{args.artifact}/{pte_filename}",
    None,
)
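
As an aside, the split nodes can also be seen without the full QNN flow; here is a minimal standalone sketch with plain torch.export (assuming a recent torch; note the raw export may still show aten.split, since the decomposition into aten.split_with_sizes happens during edge lowering):

import torch
from torch.export import export

# Standalone sketch: export the model and inspect the graph directly,
# no QNN tooling needed.
ep = export(TorchTest().eval(), (torch.ones(1, 1, 512), torch.ones(1, 1, 512)))
print(ep.graph)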

The QNN export, however, failed with the following error:

...
  File "/home/user/Projects/android/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 77, in is_node_supported
    op_wrapper = self.node_visitors[node.target.__name__].define_node(
  File "/home/user/Projects/android/executorch/backends/qualcomm/builders/op_linear.py", line 62, in define_node
    weight_tensor_wrapper = self.define_tensor(
  File "/home/user/Projects/android/executorch/backends/qualcomm/builders/node_visitor.py", line 319, in define_tensor
    dims = [1] if len(tensor.size()) == 0 else tensor.size()
AttributeError: 'NoneType' object has no attribute 'size'

After adding some print() debugging (a rough sketch of the loop follows the dump), I found that nodes like the following were generated:

9, aten_split_with_sizes_default : {'stack_trace': '  File "/home/user/Projects/android/executorch/examples/qualcomm/scripts/act.py", line 78, in forward\n    x = self.attn(q, k, value=x)\n', 'nn_module_stack': {'L__self__': ('', '__main__.TorchTest'), 'L__self___attn': ('attn', 'torch.nn.modules.activation.MultiheadAttention')}, 'torch_fn': ('multi_head_attention_forward_1', 'function.multi_head_attention_forward'), 'source_fn_stack': [('l__self___attn', <class 'torch.nn.modules.activation.MultiheadAttention'>)], 'original_aten': <OpOverload(op='aten.split', overload='Tensor')>, 'from_node': [('l__self___attn', 'L__self___attn'), ('split', <OpOverload(op='aten.split', overload='Tensor')>)], 'seq_nr': 147, 'val': [FakeTensor(..., size=(512, 512), grad_fn=<SplitWithSizesBackward0>), FakeTensor(..., size=(512, 512), grad_fn=<SplitWithSizesBackward0>), FakeTensor(..., size=(512, 512), grad_fn=<SplitWithSizesBackward0>)], 'tensor_meta': [None, None, None], 'debug_handle': 8}
10, getitem : {'stack_trace': '  File "/home/user/Projects/android/executorch/examples/qualcomm/scripts/act.py", line 78, in forward\n    x = self.attn(q, k, value=x)\n', 'nn_module_stack': {'L__self__': ('', '__main__.TorchTest'), 'L__self___attn': ('attn', 'torch.nn.modules.activation.MultiheadAttention')}, 'torch_fn': ('multi_head_attention_forward_1', 'function.multi_head_attention_forward'), 'source_fn_stack': [('l__self___attn', <class 'torch.nn.modules.activation.MultiheadAttention'>)], 'original_aten': <OpOverload(op='aten.split', overload='Tensor')>, 'from_node': [('l__self___attn', 'L__self___attn'), ('split', <OpOverload(op='aten.split', overload='Tensor')>)], 'seq_nr': 147, 'val': FakeTensor(..., size=(512, 512), grad_fn=<SplitWithSizesBackward0>), 'tensor_meta': None, 'debug_handle': 9}
11, getitem_1 : {'stack_trace': '  File "/home/user/Projects/android/executorch/examples/qualcomm/scripts/act.py", line 78, in forward\n    x = self.attn(q, k, value=x)\n', 'nn_module_stack': {'L__self__': ('', '__main__.TorchTest'), 'L__self___attn': ('attn', 'torch.nn.modules.activation.MultiheadAttention')}, 'torch_fn': ('multi_head_attention_forward_1', 'function.multi_head_attention_forward'), 'source_fn_stack': [('l__self___attn', <class 'torch.nn.modules.activation.MultiheadAttention'>)], 'original_aten': <OpOverload(op='aten.split', overload='Tensor')>, 'from_node': [('l__self___attn', 'L__self___attn'), ('split', <OpOverload(op='aten.split', overload='Tensor')>)], 'seq_nr': 147, 'val': FakeTensor(..., size=(512, 512), grad_fn=<SplitWithSizesBackward0>), 'tensor_meta': None, 'debug_handle': 10}
12, getitem_2 : {'stack_trace': '  File "/home/user/Projects/android/executorch/examples/qualcomm/scripts/act.py", line 78, in forward\n    x = self.attn(q, k, value=x)\n', 'nn_module_stack': {'L__self__': ('', '__main__.TorchTest'), 'L__self___attn': ('attn', 'torch.nn.modules.activation.MultiheadAttention')}, 'torch_fn': ('multi_head_attention_forward_1', 'function.multi_head_attention_forward'), 'source_fn_stack': [('l__self___attn', <class 'torch.nn.modules.activation.MultiheadAttention'>)], 'original_aten': <OpOverload(op='aten.split', overload='Tensor')>, 'from_node': [('l__self___attn', 'L__self___attn'), ('split', <OpOverload(op='aten.split', overload='Tensor')>)], 'seq_nr': 147, 'val': FakeTensor(..., size=(512, 512), grad_fn=<SplitWithSizesBackward0>), 'tensor_meta': None, 'debug_handle': 11}
13, aten_split_with_sizes_default_1 : {'stack_trace': '  File "/home/user/Projects/android/executorch/examples/qualcomm/scripts/act.py", line 78, in forward\n    x = self.attn(q, k, value=x)\n', 'nn_module_stack': {'L__self__': ('', '__main__.TorchTest'), 'L__self___attn': ('attn', 'torch.nn.modules.activation.MultiheadAttention')}, 'torch_fn': ('multi_head_attention_forward_1', 'function.multi_head_attention_forward'), 'source_fn_stack': [('l__self___attn', <class 'torch.nn.modules.activation.MultiheadAttention'>)], 'original_aten': <OpOverload(op='aten.split', overload='Tensor')>, 'from_node': [('l__self___attn', 'L__self___attn'), ('split_1', <OpOverload(op='aten.split', overload='Tensor')>)], 'seq_nr': 147, 'val': [FakeTensor(..., size=(512,), grad_fn=<SplitWithSizesBackward0>), FakeTensor(..., size=(512,), grad_fn=<SplitWithSizesBackward0>), FakeTensor(..., size=(512,), grad_fn=<SplitWithSizesBackward0>)], 'tensor_meta': [None, None, None], 'debug_handle': 12}
14, getitem_3 : {'stack_trace': '  File "/home/user/Projects/android/executorch/examples/qualcomm/scripts/act.py", line 78, in forward\n    x = self.attn(q, k, value=x)\n', 'nn_module_stack': {'L__self__': ('', '__main__.TorchTest'), 'L__self___attn': ('attn', 'torch.nn.modules.activation.MultiheadAttention')}, 'torch_fn': ('multi_head_attention_forward_1', 'function.multi_head_attention_forward'), 'source_fn_stack': [('l__self___attn', <class 'torch.nn.modules.activation.MultiheadAttention'>)], 'original_aten': <OpOverload(op='aten.split', overload='Tensor')>, 'from_node': [('l__self___attn', 'L__self___attn'), ('split_1', <OpOverload(op='aten.split', overload='Tensor')>)], 'seq_nr': 147, 'val': FakeTensor(..., size=(512,), grad_fn=<SplitWithSizesBackward0>), 'tensor_meta': None, 'debug_handle': 13}
15, getitem_4 : {'stack_trace': '  File "/home/user/Projects/android/executorch/examples/qualcomm/scripts/act.py", line 78, in forward\n    x = self.attn(q, k, value=x)\n', 'nn_module_stack': {'L__self__': ('', '__main__.TorchTest'), 'L__self___attn': ('attn', 'torch.nn.modules.activation.MultiheadAttention')}, 'torch_fn': ('multi_head_attention_forward_1', 'function.multi_head_attention_forward'), 'source_fn_stack': [('l__self___attn', <class 'torch.nn.modules.activation.MultiheadAttention'>)], 'original_aten': <OpOverload(op='aten.split', overload='Tensor')>, 'from_node': [('l__self___attn', 'L__self___attn'), ('split_1', <OpOverload(op='aten.split', overload='Tensor')>)], 'seq_nr': 147, 'val': FakeTensor(..., size=(512,), grad_fn=<SplitWithSizesBackward0>), 'tensor_meta': None, 'debug_handle': 14}
16, getitem_5 : {'stack_trace': '  File "/home/user/Projects/android/executorch/examples/qualcomm/scripts/act.py", line 78, in forward\n    x = self.attn(q, k, value=x)\n', 'nn_module_stack': {'L__self__': ('', '__main__.TorchTest'), 'L__self___attn': ('attn', 'torch.nn.modules.activation.MultiheadAttention')}, 'torch_fn': ('multi_head_attention_forward_1', 'function.multi_head_attention_forward'), 'source_fn_stack': [('l__self___attn', <class 'torch.nn.modules.activation.MultiheadAttention'>)], 'original_aten': <OpOverload(op='aten.split', overload='Tensor')>, 'from_node': [('l__self___attn', 'L__self___attn'), ('split_1', <OpOverload(op='aten.split', overload='Tensor')>)], 'seq_nr': 147, 'val': FakeTensor(..., size=(512,), grad_fn=<SplitWithSizesBackward0>), 'tensor_meta': None, 'debug_handle': 15}
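
For reference, the dump above came from a loop roughly like this; edge_program is my name for the exported program at the point where the partitioner runs, so treat it as an assumption:

# Sketch of the debug loop that produced the dump above.
# `edge_program` is assumed to be the ExportedProgram being partitioned.
for i, node in enumerate(edge_program.graph_module.graph.nodes):
    print(f"{i}, {node.name} : {node.meta}")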

The 'NoneType' error is raised for node getitem_2: in op_linear.py, weight_tensor is None (and for node getitem_5, bias_tensor is None). As an experiment, I tried falling back to node.meta["val"] for these None tensors, which may well not be correct.
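Roughly, the experiment looked like this hypothetical sketch inside define_tensor() in node_visitor.py (the surrounding code is inferred from the traceback above, not copied from the source):

# Hypothetical sketch of the experiment: fall back to the FakeTensor in
# node.meta when no real tensor is available. Probably not a proper fix.
if tensor is None:
    tensor = node.meta["val"]
dims = [1] if len(tensor.size()) == 0 else tensor.size()

With that change, the following error appeared: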

...
  File "/home/user/Projects/android/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 77, in is_node_supported
    op_wrapper = self.node_visitors[node.target.__name__].define_node(
  File "/home/user/Projects/android/executorch/backends/qualcomm/builders/op_split_with_sizes.py", line 68, in define_node
    dim = cast(int, node.args[2])
IndexError: tuple index out of range

The failing node is aten_split_with_sizes_default_1, and its args are (p_attn_in_proj_bias, [512, 512, 512]). That tuple has only two elements, so node.args[2] does indeed raise "tuple index out of range".
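
Since the aten.split_with_sizes schema declares dim with a default of 0, one plausible guard in op_split_with_sizes.py (a suggestion on my part, not a vetted patch) would be:

# Plausible guard (suggestion only): aten.split_with_sizes defaults
# dim to 0 when the argument is omitted, so avoid indexing args[2]
# unconditionally.
dim = cast(int, node.args[2]) if len(node.args) > 2 else 0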
I suspect this is a bug in the handling of this specific op. I also tried the same model code with the MPS backend, using a command like:

python3 -m examples.apple.mps.scripts.mps_example --model_name="test" --bundled

The Edge IR Graph printed is:

graph():
...
    %aten_split_with_sizes_copy_default : [num_users=3] = call_function[target=executorch.exir.dialects.edge._ops.aten.split_with_sizes_copy.default](args = (%p_attn_in_proj_weight, [512, 512, 512]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%aten_split_with_sizes_copy_default, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%aten_split_with_sizes_copy_default, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%aten_split_with_sizes_copy_default, 2), kwargs = {})
    %aten_split_with_sizes_copy_default_1 : [num_users=3] = call_function[target=executorch.exir.dialects.edge._ops.aten.split_with_sizes_copy.default](args = (%p_attn_in_proj_bias, [512, 512, 512]), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%aten_split_with_sizes_copy_default_1, 0), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%aten_split_with_sizes_copy_default_1, 1), kwargs = {})
    %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%aten_split_with_sizes_copy_default_1, 2), kwargs = {})
...

It gave a similar "tuple index out of range" error to the one from the QNN backend, this time for the node named "aten_split_with_sizes_copy_default", so the problem does not seem to be QNN-specific.
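
For context, these split nodes come from nn.MultiheadAttention itself: it stores a packed (3 * embed_dim, embed_dim) in-projection weight plus a packed bias, and multi_head_attention_forward splits them into q/k/v pieces, which traces as aten.split and is later decomposed into split_with_sizes. A runnable illustration (my reconstruction, not the exact library code):

import torch

# Why the graphs above contain splits over the attention parameters:
# the packed in-projection is split into query/key/value parts.
embed_dim = 512
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)
w_q, w_k, w_v = in_proj_weight.split([embed_dim] * 3)  # -> aten.split
b_q, b_k, b_v = in_proj_bias.split([embed_dim] * 3)
print(w_q.shape, b_q.shape)  # torch.Size([512, 512]) torch.Size([512])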

Versions

Python: 3.10.16
ExecuTorch: 0.4
QNN SDK: 2.29.0.241129

cc @cccclai @winskuo-quic @shewu-quic
