
MatMul quantisation fails with XNNPack backend #8138

Open
@metinsuloglu

Description


🐛 Describe the bug

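XNNPACKQuantizer silently fails to quantise torch.matmul: no error or warning is raised, but the matmul is left in floating point. This is visible both when printing the converted model (no quantize/dequantize ops are inserted around the matmul, whereas torch.add is quantised as expected) and when profiling the exported program with xnn_executor_runner.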

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer
)
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner
)


class MatMulModel(torch.nn.Module):
    """ torch.matmul """
    def forward(self, inputs):
        return torch.matmul(*inputs)


class AddModel(torch.nn.Module):
    """ torch.add """
    def forward(self, inputs):
        return torch.add(*inputs)


def quantize(model, example_inputs):
    print(f"Original model: {model}")

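    # Apply the XNNPACK symmetric quantisation config globally.
    quantizer = XNNPACKQuantizer()
    operator_config = get_symmetric_quantization_config(is_per_channel=False)
    quantizer.set_global(operator_config)

    # Annotate the graph and insert observers, calibrate on the example
    # inputs, then replace the observers with quantize/dequantize ops.
    m = prepare_pt2e(model, quantizer)
    m(*example_inputs)
    m = convert_pt2e(m)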

    print(f"Quantized model: {m}")
    return m


def export(model, example_inputs, fname):
    model = torch.export.export_for_training(model, example_inputs).module()

    quantized_model = quantize(model, example_inputs)

    edge = to_edge_transform_and_lower(
        torch.export.export(quantized_model, example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()]
    )
    exec_prog = edge.to_executorch()

    with open(fname, "wb") as file:
        exec_prog.write_to_file(file)


### MatMul ###
print("\ntorch.matmul:\n")
model = MatMulModel().eval()
sample_inputs = ((torch.randn((1, 4, 16, 16)), torch.randn((16, 16))),)

export(model, sample_inputs, "qs8_xnnpack_matmul.pte")

### Add ###
print("\ntorch.add:\n")
model = AddModel().eval()
sample_inputs = ((torch.randn((1, 4, 16, 16)), torch.randn((16, 16))),)

export(model, sample_inputs, "qs8_xnnpack_add.pte")
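To narrow down whether the matmul is skipped at annotation time (rather than later in convert_pt2e or in the partitioner), one could list the operator nodes that carry no quantization annotation after running only the quantizer's annotation step. This is a sketch under two assumptions: that PT2E quantizers store their per-node annotation in `node.meta["quantization_annotation"]`, and that `Quantizer.annotate(model)` returns the annotated GraphModule. The helper name `unannotated_ops` is hypothetical.

```python
from typing import Iterable, List

# Assumed meta key under which PT2E quantizers store per-node annotations.
ANNOTATION_KEY = "quantization_annotation"


def unannotated_ops(nodes: Iterable) -> List[str]:
    """Return the names of call_function nodes that the quantizer left
    without an annotation (i.e. ops it will not quantise)."""
    missing = []
    for node in nodes:
        # Placeholders and outputs are never annotated directly; only check ops.
        if node.op == "call_function" and ANNOTATION_KEY not in node.meta:
            missing.append(node.name)
    return missing
```

Inside `quantize()`, something like `unannotated_ops(quantizer.annotate(copy.deepcopy(model)).graph.nodes)` would be one way to use it; annotating a deep copy keeps the model passed to `prepare_pt2e` untouched.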

Output:

torch.matmul:

Original model: GraphModule()

def forward(self, inputs):
    inputs_0, inputs_1, = fx_pytree.tree_flatten_spec(([inputs], {}), self._in_spec)
    matmul = torch.ops.aten.matmul.default(inputs_0, inputs_1);  inputs_0 = inputs_1 = None
    return pytree.tree_unflatten((matmul,), self._out_spec)

Quantized model: GraphModule()

def forward(self, inputs):
    inputs_0, inputs_1, = fx_pytree.tree_flatten_spec(([inputs], {}), self._in_spec)
    matmul = torch.ops.aten.matmul.default(inputs_0, inputs_1);  inputs_0 = inputs_1 = None
    return pytree.tree_unflatten((matmul,), self._out_spec)


torch.add:

Original model: GraphModule()

def forward(self, inputs):
    inputs_0, inputs_1, = fx_pytree.tree_flatten_spec(([inputs], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(inputs_0, inputs_1);  inputs_0 = inputs_1 = None
    return pytree.tree_unflatten((add,), self._out_spec)
    
Quantized model: GraphModule()

def forward(self, inputs):
    inputs_0, inputs_1, = fx_pytree.tree_flatten_spec(([inputs], {}), self._in_spec)
    quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(inputs_0, 0.028314031660556793, 10, -128, 127, torch.int8);  inputs_0 = None
    dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, 0.028314031660556793, 10, -128, 127, torch.int8);  quantize_per_tensor_default = None
    quantize_per_tensor_default_1 = torch.ops.quantized_decomposed.quantize_per_tensor.default(inputs_1, 0.02587413415312767, 33, -128, 127, torch.int8);  inputs_1 = None
    dequantize_per_tensor_default_1 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_1, 0.02587413415312767, 33, -128, 127, torch.int8);  quantize_per_tensor_default_1 = None
    add = torch.ops.aten.add.Tensor(dequantize_per_tensor_default, dequantize_per_tensor_default_1);  dequantize_per_tensor_default = dequantize_per_tensor_default_1 = None
    quantize_per_tensor_default_2 = torch.ops.quantized_decomposed.quantize_per_tensor.default(add, 0.0323091559112072, 14, -128, 127, torch.int8);  add = None
    dequantize_per_tensor_default_2 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_2, 0.0323091559112072, 14, -128, 127, torch.int8);  quantize_per_tensor_default_2 = None
    return pytree.tree_unflatten((dequantize_per_tensor_default_2,), self._out_spec)
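The quantize/dequantize pairs in the converted add graph implement per-tensor affine quantisation to int8. A pure-Python sketch of that arithmetic, assuming the usual formulas q = clamp(round(x / scale) + zero_point, quant_min, quant_max) and x ≈ (q - zero_point) * scale, using the first input's parameters from the graph above (scale 0.028314031660556793, zero point 10, range [-128, 127]). The function names mirror the graph ops for readability but are plain illustrative helpers, not the `quantized_decomposed` ops themselves.

```python
def quantize_per_tensor(x: float, scale: float, zero_point: int,
                        quant_min: int = -128, quant_max: int = 127) -> int:
    """Map a float to an int8 code: clamp(round(x / scale) + zero_point)."""
    # Python's round() rounds halves to even, like torch.round.
    q = round(x / scale) + zero_point
    return max(quant_min, min(quant_max, q))


def dequantize_per_tensor(q: int, scale: float, zero_point: int) -> float:
    """Map an int8 code back to an approximate float value."""
    return (q - zero_point) * scale


# Parameters of the first quantize/dequantize pair in the converted add graph.
SCALE = 0.028314031660556793
ZERO_POINT = 10

x = 0.5
q = quantize_per_tensor(x, SCALE, ZERO_POINT)  # q == 28
x_hat = dequantize_per_tensor(q, SCALE, ZERO_POINT)
print(q, x_hat, abs(x - x_hat))
```

With these parameters only values in roughly [-3.91, 3.31] are representable; anything outside is clamped to -128 or 127, and in-range values are recovered to within scale / 2. The sketch divides by the scale, whereas PyTorch's reference implementation may multiply by a precomputed inverse scale, so results could differ at exact rounding ties.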

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20250104+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.31.4
Libc version: glibc-2.39

Python version: 3.10.16 (main, Dec 4 2024, 08:53:38) [GCC 13.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-51-generic-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 2080 Ti

Nvidia driver version: 550.120
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i9-10900X CPU @ 3.70GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 7
CPU(s) scaling MHz: 30%
CPU max MHz: 4700.0000
CPU min MHz: 1200.0000
BogoMIPS: 7399.70
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req vnmi avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 320 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 10 MiB (10 instances)
L3 cache: 19.3 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-19
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] executorch==0.6.0a0+a5c7609
[pip3] numpy==2.2.2
[pip3] torch==2.6.0.dev20250104+cpu
[pip3] torchao==0.8.0+git11333ba2
[pip3] torchaudio==2.6.0.dev20250104+cpu
[pip3] torchsr==1.0.4
[pip3] torchvision==0.22.0.dev20250104+cpu
[pip3] triton==3.2.0
[conda] Could not collect

cc @digantdesai @mcr229 @cbilgin @freddan80 @per @zingo @oscarandersson8218

Labels

module: xnnpack, need-user-input, triaged
