
Commit 51befee

mansnilsper and Per Åstrand authored
Arm backend: Enable test_llama_tosa_BI and related fixes (#10681)
Three fixes were needed:

First, the where.self operator received a scalar_tensor input that was not quantized. This happened because the where.self quantization annotator uses the parent nodes' quantization specs, which in this case did not exist. Quantizing scalar_tensor resolves this.

Second, quantizing scalar_tensor triggered the assert "expecting kwargs for aten op IR to be empty", so scalar_tensor kwargs are now cleared to {}.

Third, quantizing -inf failed for scalar_tensor nodes. This is fixed by porting the QNN backend's pass that replaces -inf/inf, added here as the new pass ReplaceInfValues.

Co-authored-by: Per Åstrand <[email protected]>
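For context, a minimal sketch of the graph pattern behind the first problem (the toy module below is an illustrative assumption, not code from this commit; per the commit message, a Python scalar reaching where.self is exported as a scalar_tensor node):

import torch

class MaskedScores(torch.nn.Module):
    """Toy stand-in for the Llama attention-mask pattern."""

    def forward(self, scores, mask):
        # The Python scalar -inf is exported as an aten.scalar_tensor.default
        # node feeding aten.where.self. The where.self annotator borrows
        # quantization specs from its parent nodes, so before this commit the
        # scalar_tensor input ended up with no specs at all.
        return torch.where(mask, scores, float("-inf"))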
1 parent 6da46fb commit 51befee

File tree

5 files changed: +54 −3 lines


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -57,4 +57,5 @@
 from .size_adjust_conv2d_pass import SizeAdjustConv2DPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass  # noqa
+from .replace_inf_values_pass import ReplaceInfValues  # noqa  # usort: skip
 from .arm_pass_manager import ArmPassManager  # noqa  # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions

@@ -49,6 +49,7 @@
     MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
+    ReplaceInfValues,
     ReplaceScalarWithTensorArgPassTOSABI,
     ReplaceScalarWithTensorArgPassTOSAMI,
     RetraceFoldedDtypesPass,

@@ -216,4 +217,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeSoftmaxPass())

         self.add_pass(ConvertMinMaxPass())
+        self.add_pass(ReplaceInfValues())
         return self._transform(graph_module)
backends/arm/_passes/replace_inf_values_pass.py (new file)

Lines changed: 45 additions & 0 deletions

# Copyright (c) Qualcomm Innovation Center, Inc.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This pass is based on backends/qualcomm/_passes/replace_inf_values.py
# with some modifications to the replacement values for inf.

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class ReplaceInfValues(ExportPass):
    """
    Due to limitations in the Quantizer, inf/-inf must be replaced with
    more quantizable values.
    """

    def __init__(self):
        super(ReplaceInfValues, self).__init__()

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        # Replace inf/-inf in floating-point buffers.
        for buf_name, tensor in graph_module.named_buffers():
            if tensor.is_floating_point():
                modified = True
                # 255 is chosen mainly for the attention_mask in Llama,
                # to give a reasonable quant scale.
                tensor[tensor == float("inf")] = 255
                tensor[tensor == float("-inf")] = -255
                setattr(graph_module, buf_name, tensor)

        # Replace inf/-inf scalar arguments on graph nodes.
        for node in graph_module.graph.nodes:
            arg_list = list(node.args)
            for index, arg in enumerate(arg_list):
                if arg == float("-inf"):
                    modified = True
                    arg_list[index] = -255
                elif arg == float("inf"):
                    modified = True
                    arg_list[index] = +255
            node.args = tuple(arg_list)

        if modified:
            graph_module.recompile()
        return PassResult(graph_module, modified)
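As a quick illustration of the pass in isolation (a minimal sketch; the toy Mask module and the direct .call() invocation are assumptions for demonstration — in the commit the pass is scheduled by ArmPassManager's transform_for_annotation_pipeline, as shown above):

import torch
from executorch.backends.arm._passes import ReplaceInfValues

class Mask(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Typical attention-mask pattern: -inf marks masked positions.
        self.register_buffer("mask", torch.tensor([0.0, float("-inf")]))

    def forward(self, x):
        return x + self.mask

gm = torch.export.export(Mask(), (torch.ones(2),)).module()
result = ReplaceInfValues().call(gm)
assert result.modified  # the -inf entry in the buffer is now -255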

backends/arm/quantizer/quantization_annotator.py

Lines changed: 4 additions & 0 deletions

@@ -411,6 +411,9 @@ def any_or_hardtanh_min_zero(n: Node):
         shared_qspec = SharedQuantizationSpec(node.args[0])
         quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]  # type: ignore[arg-type]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
+    elif node.target in [torch.ops.aten.scalar_tensor.default]:
+        quant_properties.quant_inputs = []
+        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     else:
         return None

@@ -458,5 +461,6 @@ def annotate_graph(  # type: ignore[return]
         if node.target in [
             torch.ops.aten.full_like.default,
             torch.ops.aten.full.default,
+            torch.ops.aten.scalar_tensor.default,
         ]:
             node.kwargs = {}
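For context on the kwargs change: constant-producing nodes such as full, full_like, and scalar_tensor typically come out of export carrying kwargs (e.g. dtype), which is what trips the quantizer's "expecting kwargs for aten op IR to be empty" assert. A standalone sketch of the same clean-up (clear_constant_op_kwargs is a hypothetical helper mirroring the annotate_graph change above, not the quantizer's actual code):

import torch

def clear_constant_op_kwargs(graph_module: torch.fx.GraphModule) -> None:
    # Drop kwargs (typically dtype) from constant-producing nodes so the
    # quantizer's empty-kwargs assert on aten op IR is satisfied.
    for node in graph_module.graph.nodes:
        if node.target in [
            torch.ops.aten.full_like.default,
            torch.ops.aten.full.default,
            torch.ops.aten.scalar_tensor.default,
        ]:
            node.kwargs = {}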

backends/arm/test/models/test_llama.py

Lines changed: 2 additions & 3 deletions

@@ -105,7 +105,6 @@ def test_llama_tosa_MI(self):
             )
         )

-    @pytest.mark.xfail(reason="KeyError: scalar_tensor_1 (MLETORCH-907)")
     def test_llama_tosa_BI(self):
         llama_model, llama_inputs, llama_meta = self.prepare_model()

@@ -126,7 +125,7 @@ def test_llama_tosa_BI(self):
             .to_executorch()
             .run_method_and_compare_outputs(
                 inputs=llama_inputs,
-                atol=4.3,
-                rtol=1.1,  # TODO: Tolerance needs to be updated after MLETORCH-907
+                atol=9.9,
+                rtol=1.5,  # TODO: Tolerance needs to be updated after MLETORCH-907
             )
         )
