Cannot export quantized resnet50_clip.openai model to Edge dialect #6685

Open
@corehalt

🐛 Describe the bug

After quantizing the resnet50_clip.openai model with torch.ao quantization, the last step, exir.to_edge(), fails. This happens quite often, not only with this model but with many others:

import sys
import timm 
import torch
from executorch import exir

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
  XNNPACKQuantizer,
  get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import (
  prepare_pt2e,
  convert_pt2e,
)

model = timm.create_model('resnet50_clip.openai', pretrained=True)
model.eval()
in_shape = (1,) + model.default_cfg['input_size']
tracing_input = (torch.randn(in_shape),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(
    is_per_channel = True,
    is_dynamic = False,
    act_qmin = -128,
    act_qmax = 127,
    weight_qmin = -127,
    weight_qmax = 127)
)
exported_model = torch.export.export_for_training(model, tracing_input).module()
prepared_model = prepare_pt2e(exported_model, quantizer)
quantized_model = convert_pt2e(prepared_model, use_reference_representation=True)

aten_dialect_program = torch.export.export(quantized_model, tracing_input)
edge_dialect_program = exir.to_edge(aten_dialect_program)
executorch_program = edge_dialect_program.to_executorch()

Produces the error:

Traceback (most recent call last):
  File "repro.py", line 34, in <module>
    edge_dialect_program = exir.to_edge(aten_dialect_program)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/pyenv/versions/executorch-cpu/lib/python3.12/site-packages/executorch/exir/program/_program.py", line 1143, in to_edge
    edge_programs[name] = _generate_edge_program(name, config, program)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/pyenv/versions/executorch-cpu/lib/python3.12/site-packages/executorch/exir/program/_program.py", line 738, in _generate_edge_program
    edge_program = ExportedProgram(
                   ^^^^^^^^^^^^^^^^
  File "/pyenv/versions/executorch-cpu/lib/python3.12/site-packages/torch/export/exported_program.py", line 700, in __init__
    self.validate()
  File "/pyenv/versions/executorch-cpu/lib/python3.12/site-packages/torch/export/exported_program.py", line 1117, in validate
    self._validate()
  File "/pyenv/versions/executorch-cpu/lib/python3.12/site-packages/torch/export/exported_program.py", line 1126, in _validate
    v().check(self)
  File "/pyenv/versions/executorch-cpu/lib/python3.12/site-packages/torch/_export/verifier.py", line 155, in check
    self._check_graph_module(ep.graph_module)
  File "/pyenv/versions/executorch-cpu/lib/python3.12/site-packages/torch/_export/verifier.py", line 268, in _check_graph_module
    self.check_additional(gm)
  File "/pyenv/versions/executorch-cpu/lib/python3.12/site-packages/executorch/exir/verification/verifier.py", line 262, in check_additional
    _check_tensor_args_matching_op_allowed_dtype(gm)
  File "/pyenv/versions/executorch-cpu/lib/python3.12/site-packages/executorch/exir/verification/verifier.py", line 180, in _check_tensor_args_matching_op_allowed_dtype
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: These operators are taking Tensor inputs with mismatched dtypes: defaultdict(<class 'dict'>, {<EdgeOpOverload: aten.sub.Tensor>: schema = aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor: {'self': torch.int32, 'other': torch.int64, '__ret_0': torch.int32}})

This error happens with many other models, for example:

beit_base_patch16_224.in22k_ft_in22k
caformer_b36.sail_in1k
deit3_base_patch16_384.fb_in1k
vit_base_patch14_dinov2.lvd142m

Also reported on PyTorch: pytorch/pytorch#139718

Versions

PyTorch version: 2.5.0+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 17.0.6 (https://github.com/llvm/llvm-project.git 6009708b4367171ccdbf4b5905cb6a803753fe18)
CMake version: version 3.30.5
Libc version: glibc-2.35

Python version: 3.12.2 (main, May 16 2024, 10:08:02) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] executorch==0.4.0a0+6a085ff
[pip3] numpy==1.26.4
[pip3] onnx==1.17.0
[pip3] onnxruntime==1.19.2
[pip3] onnxslim==0.1.36
[pip3] torch==2.5.0+cpu
[pip3] torchaudio==2.5.0+cpu
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0+cpu
[conda] Could not collect

cc @JacobSzwejbka @angelayi @kimishpatel @jerryzh168

Metadata

Assignees

No one assigned

Labels

module: exir (Issues related to Export IR and the code under exir/)
module: quantization (Issues related to quantization)
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)