strided_copy operator in output graph when sample input has been transposed #16374

@etrommer

🐛 Describe the bug

When deploying with ExecuTorch, I occasionally read existing model calibration data from NumPy arrays that are in NHWC order. Whenever I do that and transpose the calibration data to NCHW, the output graph contains an as_strided_copy operator, even though I call .contiguous() on the transposed tensor.

Minimal working example to reproduce the behavior:

"""
Minimal working example of ExecuTorch PTE export functionality.
"""

import os
import torch
import torch.nn as nn
from torch.export import export
from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig, ExecutorchBackendConfig
from executorch.extension.export_util.utils import save_pte_program
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import EthosUQuantizer, get_symmetric_quantization_config
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Constants
OUTPUT_DIR = "./output"
MODEL_NAME = "test_model"
ACCELERATOR_CONFIG = "ethos-u55-128"
SYSTEM_CONFIG = "Ethos_U55_High_End_Embedded"
MEMORY_MODE = "Shared_Sram"


class SimpleModel(nn.Module):
    """Simple CNN model for testing."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(8, 10)
    
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


def prepare_input_data(single_sample=False):
    """Prepare dummy input data in channels-first format."""
    dummy_data_nchw = torch.rand(100, 1, 49, 10)  # Dummy data for example
    
    if single_sample:
        dummy_data_nchw = dummy_data_nchw[0:1]
    
    return (dummy_data_nchw,)


def quantize_model(model):
    """Quantize model using EthosU quantizer."""
    # Create compile spec
    compile_spec = EthosUCompileSpec(
        ACCELERATOR_CONFIG,
        system_config=SYSTEM_CONFIG,
        memory_mode=MEMORY_MODE,
        extra_flags=["--verbose-operators", "--verbose-cycle-estimate"],
    )
    
    # Setup quantizer
    quantizer = EthosUQuantizer(compile_spec)
    operator_config = get_symmetric_quantization_config()
    quantizer.set_global(operator_config)
    
    model = torch.export.export(model, prepare_input_data(single_sample=True), strict=True).module()
    prepared_model = prepare_pt2e(model, quantizer)
    
    # Calibrate with dummy data
    calibration_inputs = prepare_input_data()[0]
    for x in calibration_inputs:
        prepared_model(x)
    
    # Convert to quantized model
    return convert_pt2e(prepared_model)


def lower_to_arm_backend(exported_program):
    """Apply ARM backend transformations."""
    compile_spec = EthosUCompileSpec(
        ACCELERATOR_CONFIG,
        system_config=SYSTEM_CONFIG,
        memory_mode=MEMORY_MODE,
        extra_flags=["--verbose-operators", "--verbose-cycle-estimate"],
    )
    
    partitioner = EthosUPartitioner(compile_spec)
    
    edge_program_manager = to_edge_transform_and_lower(
        exported_program,
        partitioner=[partitioner],
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    return edge_program_manager


def export_pte_example():
    """Main export function - minimal working example."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
    print("Creating model...")
    model = SimpleModel()
    
    print("Preparing input data...")
    tracing_inputs = prepare_input_data(single_sample=True)
    
    print("Quantizing model...")
    quantized_model = quantize_model(model)
    
    print("Exporting quantized model...")
    exported_program = export(quantized_model, tracing_inputs, strict=True)
    
    print("Lowering to ARM backend...")
    edge_program = lower_to_arm_backend(exported_program)
    
    print("Creating ExecuTorch program...")
    exec_prog = edge_program.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )
    
    print("Saving PTE model...")
    output_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_quantized.pte")
    save_pte_program(exec_prog, output_path)
    
    print(f"Successfully exported model to: {output_path}")
    return output_path


if __name__ == "__main__":
    try:
        pte_path = export_pte_example()
        print("Export completed successfully!")
    except Exception as e:
        print(f"Export failed: {e}")
        raise

With this version, I get an output graph that can be delegated entirely to the U55.

When I change prepare_input_data to this:

def prepare_input_data(single_sample=False):
    """Prepare dummy input data in channels-first format."""
    dummy_data_nhwc = torch.rand(100, 49, 10, 1)
    
    # Transpose from channels-last to channels-first
    axes = [0, -1] + list(range(1, len(dummy_data_nhwc.shape) - 1))
    dummy_data_nchw = dummy_data_nhwc.permute(axes).contiguous()
    
    if single_sample:
        dummy_data_nchw = dummy_data_nchw[0:1]
    
    return (dummy_data_nchw,)

I get an output graph containing an as_strided_copy operator. It is caused by the AdaptiveAvgPool2d layer and cannot be delegated to the U55.

It is my understanding that contiguous() should materialize the calibration data as an NCHW tensor in memory, so I would expect both code snippets to produce the same output model.
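One thing that may be worth checking here is whether .contiguous() actually copies anything in this case. The sketch below compares the strides of the two variants of the sample input. It only inspects tensor metadata and is not a confirmed diagnosis of the export behavior; the variable names are illustrative and not part of the original script.

```python
import torch

# Canonical NCHW tensor, as in the first version of prepare_input_data.
reference = torch.rand(100, 1, 49, 10)

# NHWC data permuted to NCHW, as in the second version.
nhwc = torch.rand(100, 49, 10, 1)
permuted = nhwc.permute(0, 3, 1, 2)
nchw = permuted.contiguous()

print(reference.shape == nchw.shape)  # True: same logical shape
print(reference.stride())             # (490, 490, 10, 1)
print(nchw.stride())                  # (490, 1, 10, 1)

# With a single channel, the permuted view already passes the contiguity
# check (size-1 dimensions are ignored), so .contiguous() returns the same
# storage without copying and the stride of the channel dimension stays 1.
print(permuted.is_contiguous())                # True
print(nchw.data_ptr() == permuted.data_ptr())  # True

# Cloning into the default contiguous format produces canonical strides.
# Whether this changes the exported graph is an assumption not verified here.
normalized = nchw.clone(memory_format=torch.contiguous_format)
print(normalized.stride())            # (490, 490, 10, 1)
```

If the stride difference turns out to matter, it would explain why two inputs with the same shape and values lead to different graphs, but that link would still need to be confirmed against the exporter.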

I am not sure whether this is an ExecuTorch issue or a torch.export() issue. The first step where I noticed a divergence in intermediate results was after the call to to_edge_transform_and_lower().
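To narrow down the stage at which the operator first appears, the operators in each intermediate graph can be listed and compared. The following is a minimal sketch; count_ops and find_ops are hypothetical helper names, and the stand-in graph only exists so the snippet runs without ExecuTorch installed. It relies only on the op and target attributes of the graph nodes.

```python
from collections import Counter
from types import SimpleNamespace


def count_ops(graph_module):
    """Count call_function targets in the top-level graph of a GraphModule."""
    counts = Counter()
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            counts[str(node.target)] += 1
    return counts


def find_ops(graph_module, needle="as_strided"):
    """Return the sorted target names that contain `needle`."""
    return sorted(name for name in count_ops(graph_module) if needle in name)


# Stand-in graph so the sketch runs on its own.
fake_nodes = [
    SimpleNamespace(op="placeholder", target="x"),
    SimpleNamespace(op="call_function", target="aten.convolution.default"),
    SimpleNamespace(op="call_function", target="aten.as_strided_copy.default"),
    SimpleNamespace(op="output", target="output"),
]
fake_module = SimpleNamespace(graph=SimpleNamespace(nodes=fake_nodes))
print(find_ops(fake_module))  # ['aten.as_strided_copy.default']

# Intended use with the objects from the script above (assumed, not run here):
#   find_ops(exported_program.graph_module)
#   find_ops(edge_program_manager.exported_program().graph_module)
```

Note that this only walks the top-level graph. Operators that were delegated are wrapped in lowered submodules and will not be listed individually, so an operator that shows up here after lowering is one that stayed outside the delegate.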

[Image: output graph with the permute operation]

[Image: output graph without the permute operation]

Versions

PyTorch version: 2.9.1+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: 18.1.8 (++20240731025043+3b5b5c1ec4a3-1~exp1~20240731145144.92)
CMake version: version 4.1.2
Libc version: glibc-2.39

Python version: 3.12.3 (main, Nov  6 2025, 13:44:16) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               12
On-line CPU(s) list:                  0-11
Vendor ID:                            GenuineIntel
Model name:                           13th Gen Intel(R) Core(TM) i7-1365U
CPU family:                           6
Model:                                186
Thread(s) per core:                   2
Core(s) per socket:                   6
Socket(s):                            1
Stepping:                             3
BogoMIPS:                             5375.89
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni vnmi umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    Microsoft
Virtualization type:                  full
L1d cache:                            288 KiB (6 instances)
L1i cache:                            192 KiB (6 instances)
L2 cache:                             7.5 MiB (6 instances)
L3 cache:                             12 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-11
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect

ExecuTorch version is 1.0.1

cc @JacobSzwejbka @angelayi

Labels

module: arm (Issues related to arm backend)
module: exir (Issues related to Export IR and the code under exir/)
