
XNN AvgPool2d fails to lower with single-element kernel_size #10968

@GregoryComer

Description


🐛 Describe the bug

When lowering an AvgPool2d module whose kernel_size is a single-element list, XNN's partitioning logic errors out because it assumes that a list-valued kernel_size always has two elements. When the list contains only one element, that value should be used for both height and width. The partitioner already handles scalar kernel sizes this way, but not single-element lists.
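The expected normalization can be sketched as a small helper. `expand_kernel_size` is a hypothetical name used only for illustration, not an existing ExecuTorch function; the sketch assumes `kernel_size` arrives as either an `int` or a list/tuple of one or two ints.

```python
from typing import List, Tuple, Union


def expand_kernel_size(kernel_size: Union[int, List[int], Tuple[int, ...]]) -> Tuple[int, int]:
    """Return (kernel_h, kernel_w), broadcasting a scalar or single-element list.

    Hypothetical helper for illustration; not part of the ExecuTorch API.
    """
    if isinstance(kernel_size, int):
        # Scalar: same size for height and width.
        return (kernel_size, kernel_size)
    sizes = list(kernel_size)
    if len(sizes) == 1:
        # Single-element list: reuse the one value for both dimensions.
        return (sizes[0], sizes[0])
    if len(sizes) == 2:
        return (sizes[0], sizes[1])
    raise ValueError(f"expected 1 or 2 kernel dimensions, got {len(sizes)}")


print(expand_kernel_size(3))       # (3, 3)
print(expand_kernel_size([3]))     # (3, 3)
print(expand_kernel_size([2, 4]))  # (2, 4)
```

With this normalization, `[3]` and `3` both yield a 3x3 window, matching how `torch.nn.AvgPool2d` treats a scalar kernel size.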

Repro:

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AvgPool2d([3])
    
    def forward(self, x):
        return self.pool(x)

inputs = (
    torch.randn(1, 3, 16, 16),
)

ep = torch.export.export(Model(), inputs)
print(ep)
et_program = to_edge_transform_and_lower(
    ep,  
    partitioner=[XnnpackPartitioner()]
).to_executorch()

Output:

File ~/miniconda3/envs/pytorch/lib/python3.11/site-packages/executorch/backends/xnnpack/partition/config/generic_node_configs.py:146, in AvgPoolingConfig.check_constraints(self, node, ep)
    143     count_include_pad = cast(bool, args[5])
    145 kernel_size = cast(List[int], args[1])
--> 146 pooling_region = kernel_size[0] * kernel_size[1]
    147 divisor_override = pooling_region  # Default divisor is pooling_region
    148 if len(args) >= 7:

IndexError: list index out of range
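A minimal sketch of how the failing computation could tolerate a single-element list, written as a standalone hypothetical function `pooling_region_and_divisor` rather than as a patch to the real `AvgPoolingConfig.check_constraints` method. It assumes `args` follows the positional layout of `aten.avg_pool2d` (consistent with the traceback reading `count_include_pad` from `args[5]`) and that `args[6]`, when present and not `None`, is `divisor_override`. The real code after line 148 is not shown above, so this is an assumption.

```python
from typing import Any, List, Sequence, Tuple, cast


def pooling_region_and_divisor(args: Sequence[Any]) -> Tuple[int, int]:
    # Assumed layout: (input, kernel_size, stride, padding,
    #                  ceil_mode, count_include_pad, divisor_override)
    raw_kernel = args[1]
    if isinstance(raw_kernel, int):
        kernel_size = [raw_kernel]
    else:
        kernel_size = list(cast(List[int], raw_kernel))
    if len(kernel_size) == 1:
        # A single-element list applies to both height and width.
        kernel_size = kernel_size * 2
    pooling_region = kernel_size[0] * kernel_size[1]
    divisor = pooling_region  # Default divisor is the pooling region.
    if len(args) >= 7 and args[6] is not None:
        divisor = cast(int, args[6])
    return pooling_region, divisor


print(pooling_region_and_divisor((None, [3])))  # (9, 9)
```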

Versions

N/A

cc @digantdesai @mcr229 @cbilgin

Labels: backend tester (This bug was found by the backend test suite.), module: xnnpack (Issues related to xnnpack delegation and the code under backends/xnnpack/)
