🐛 Describe the bug
The example from https://docs.pytorch.org/executorch/main/backends-coreml.html#quantization does not work.

Environment: Python 3.10.18, executorch 8dde918 (viable/strict), installed from source per https://docs.pytorch.org/executorch/main/using-executorch-building-from-source.html#install-executorch-pip-package-from-source via `./install_executorch.sh`.
Quantization script (from the docs):
```python
import torch
import coremltools as ct
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.apple.coreml.compiler import CoreMLBackend

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Step 1: Define a LinearQuantizerConfig and create an instance of a CoreMLQuantizer
quantization_config = ct.optimize.torch.quantization.LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": ct.optimize.torch.quantization.QuantizationScheme.symmetric,
            "milestones": [0, 0, 10, 10],
            "activation_dtype": torch.quint8,
            "weight_dtype": torch.qint8,
            "weight_per_channel": True,
        }
    }
)
quantizer = CoreMLQuantizer(quantization_config)

# Step 2: Export the model for training
training_gm = torch.export.export_for_training(mobilenet_v2, sample_inputs).module()

# Step 3: Prepare the model for quantization
prepared_model = prepare_pt2e(training_gm, quantizer)

# Step 4: Calibrate the model on representative data
# Replace with your own calibration data
for calibration_sample in [torch.randn(1, 3, 224, 224)]:
    prepared_model(calibration_sample)

# Step 5: Convert the calibrated model to a quantized model
quantized_model = convert_pt2e(prepared_model)

# Step 6: Export the quantized model to CoreML
et_program = to_edge_transform_and_lower(
    torch.export.export(quantized_model, sample_inputs),
    partitioner=[
        CoreMLPartitioner(
            # iOS17 is required for the quantized ops in this example
            compile_specs=CoreMLBackend.generate_compile_specs(
                minimum_deployment_target=ct.target.iOS17
            )
        )
    ],
).to_executorch()

model_file_name = "mnv2_int8_coreml.pt"
with open(model_file_name, "wb") as file:
    et_program.write_to_file(file)
```
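This script completes and writes `mnv2_int8_coreml.pt` without errors. Note that Step 6 pins `minimum_deployment_target=ct.target.iOS17`. As a hedged pre-flight check (my addition, not part of the docs example, under the assumption that iOS17-level Core ML on Mac ships with macOS 14 or newer), one could warn when the host looks too old to compile such a program locally:

```python
# Hypothetical pre-flight check (not in the docs): warn if the host macOS is
# likely too old to compile models lowered for the iOS17 deployment target.
# Assumption: on Mac, iOS17-level Core ML features require macOS 14+.
import platform

ver_str = platform.mac_ver()[0]  # e.g. "13.4" on this machine
major = int(ver_str.split(".")[0]) if ver_str else 0
if major and major < 14:
    print(
        f"Warning: macOS {ver_str} detected; programs lowered with "
        "minimum_deployment_target=ct.target.iOS17 may fail to compile locally."
    )
```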
Validation script (also from https://docs.pytorch.org/executorch/main/backends-coreml.html#quantization):
```python
import torch
from executorch.runtime import Runtime

runtime = Runtime.get()

input_tensor = torch.randn(1, 3, 224, 224)
program = runtime.load_program("mnv2_int8_coreml.pt")
method = program.load_method("forward")
outputs = method.execute([input_tensor])
print(outputs)
```
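For reference, a minimal FP32 control case, a sketch of what the working baseline presumably looks like (the original report does not include it; this follows the non-quantized flow from the same docs page, with no pinned deployment target):

```python
# Hypothetical FP32 control case: same export-and-lower flow, no quantization
# and no iOS17 compile specs, so the default deployment target is used.
import torch
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

et_program = to_edge_transform_and_lower(
    torch.export.export(mobilenet_v2, sample_inputs),
    partitioner=[CoreMLPartitioner()],
).to_executorch()

with open("mnv2_fp32_coreml.pt", "wb") as file:
    et_program.write_to_file(file)
```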
The FP32 program loads and runs fine, but the quantized model fails with the following error:
```
$ python test_mn.py
Found library at: /Users/devuser/dlyakhov/executorch/env/lib/python3.10/site-packages/torchao/libtorchao_ops_aten.dylib
W0627 16:50:55.298000 69662 torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
/Users/devuser/dlyakhov/executorch/env/lib/python3.10/site-packages/executorch/exir/dialects/edge/_ops.py:9: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  import pkg_resources
[program.cpp:135] InternalConsistency verification requested but not available
[ETCoreMLModelManager.mm:528] Cache Miss: Model with identifier=executorch_6559be5c-8b99-48a7-a94d-4252a47a4ea5_all was not found in the models cache.
[ETCoreMLModelCompiler.mm:55] [Core ML] Failed to compile model, error = Error Domain=com.apple.mlassetio Code=1 "Failed to parse the model specification. Error: Unable to parse ML Program: at unknown location: Unknown opset 'CoreML7'." UserInfo={NSLocalizedDescription=Failed to par$
[backend_delegate.mm:288] [Core ML] Model init failed Failed to compile model, error = Error Domain=com.apple.mlassetio Code=1 "Failed to parse the model specification. Error: Unable to parse ML Program: at unknown location: Unknown opset 'CoreML7'." UserInfo={NSLocalizedDescript$
[coreml_backend_delegate.mm:193] CoreMLBackend: Failed to init the model.
[method.cpp:113] Init failed for backend CoreMLBackend: 0x23
Traceback (most recent call last):
  File "/Users/devuser/dlyakhov/executorch/examples/openvino/test_mn.py", line 8, in <module>
    program = runtime.load_program("mnv2_int8_coreml.pt")
  File "/Users/devuser/dlyakhov/executorch/env/lib/python3.10/site-packages/executorch/runtime/__init__.py", line 202, in load_program
    m = self._legacy_module._load_for_executorch(
RuntimeError: loading method forward failed with error 0x23
```
According to the error, the problem is introduced by the quantization script rather than by the validation script: the Core ML compiler rejects the lowered program's `CoreML7` MIL opset, i.e. the opset that corresponds to the `ct.target.iOS17` deployment target requested in Step 6.
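One way to narrow this down (my addition, not from the report or the docs): on Mac, iOS17-level Core ML support is generally tied to macOS 14, while this machine runs macOS 13.4, so a trivial standalone coremltools conversion pinned to iOS17 should fail to compile with the same unknown-opset error if the host OS is the limiting factor. A minimal sketch, assuming coremltools is installed and using a hypothetical `AddOne` module:

```python
# Hypothetical standalone check, independent of ExecuTorch: convert a trivial
# model targeting iOS17 and attempt to load/compile it on this machine.
import torch
import coremltools as ct


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1.0


example = torch.randn(1, 3)
traced = torch.jit.trace(AddOne().eval(), example)

# skip_model_load=True returns the converted model without compiling it,
# so the conversion step itself succeeds even on an older OS.
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(shape=example.shape)],
    minimum_deployment_target=ct.target.iOS17,
    skip_model_load=True,
)
mlmodel.save("ios17_check.mlpackage")

# Loading the saved package forces compilation; on macOS 13.x this is where
# an "Unknown opset 'CoreML7'"-style failure would be expected to surface.
loaded = ct.models.MLModel("ios17_check.mlpackage")
print("Loaded OK, spec version:", loaded.get_spec().specificationVersion)
```

If this check fails the same way, the docs example is likely fine and local validation simply needs macOS 14+ (or an iOS17 device); if it succeeds, the problem really is in the ExecuTorch lowering path.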
Versions
```
PyTorch version: 2.8.0.dev20250601
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.4 (arm64)
GCC version: Could not collect
Clang version: 14.0.3 (clang-1403.0.22.14.1)
CMake version: version 3.31.6
Libc version: N/A

Python version: 3.10.18 (main, Jun 3 2025, 18:23:41) [Clang 15.0.0 (clang-1500.1.0.2.5)] (64-bit runtime)
Python platform: macOS-13.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Pro

Versions of relevant libraries:
[pip3] executorch==0.7.0a0+8dde918
[pip3] numpy==2.2.6
[pip3] pytorch_tokenizers==0.1.0
[pip3] torch==2.8.0.dev20250601
[pip3] torchao==0.12.0+gitbc68b11f
[pip3] torchaudio==2.8.0.dev20250601
[pip3] torchdata==0.11.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.6.1
[pip3] torchvision==0.23.0.dev20
```