Skip to content

Commit 1ca3252

Browse files
SaoirseARM and zingo authored
Arm backend: Add LSTM test for int16x8 (#15524)
### Summary

Adds test cases to the LSTM model unit test for int16x8 support.

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

Signed-off-by: Saoirse Stewart <[email protected]>
Co-authored-by: Zingo Andersen <[email protected]>
1 parent 50f6b0c commit 1ca3252

File tree

1 file changed

+77
-1
lines changed

1 file changed

+77
-1
lines changed

backends/arm/test/models/test_lstm_arm.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55

66
from typing import Tuple
77

8+
import pytest
89
import torch
10+
from executorch.backends.arm.quantizer.arm_quantizer import (
11+
get_symmetric_a16w8_quantization_config,
12+
TOSAQuantizer,
13+
)
914

10-
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test import common, conftest
1116
from executorch.backends.arm.test.tester.test_pipeline import (
1217
EthosU55PipelineINT,
1318
EthosU85PipelineINT,
@@ -16,6 +21,9 @@
1621
VgfPipeline,
1722
)
1823

24+
from executorch.backends.arm.tosa import TosaSpecification
25+
from executorch.backends.xnnpack.test.tester import Quantize
26+
1927
from torch.nn.quantizable.modules import rnn
2028

2129
input_t = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] # (h0, c0)
@@ -134,3 +142,71 @@ def test_lstm_vgf_FP():
134142
use_to_edge_transform_and_lower=True,
135143
)
136144
pipeline.run()
145+
146+
147+
def get_symmetric_a16w8_lstm_quantizer(per_channel_quantization=False):
    """Build the ``Quantize`` stage used by the 16A8W LSTM tests.

    A ``TOSAQuantizer`` is created for the ``TOSA-1.0+INT+int16``
    specification and given a symmetric a16w8 quantization config as its
    global config. The same config is also handed to the returned
    ``Quantize`` stage.

    Args:
        per_channel_quantization: Forwarded as ``is_per_channel`` to
            ``get_symmetric_a16w8_quantization_config``.

    Returns:
        A ``Quantize`` stage wrapping the configured quantizer.

    Raises:
        KeyError: If the ``tosa_version`` test option has no entry in the
            profile table below (only ``"1.0"`` is listed).
    """
    tosa_version = conftest.get_option("tosa_version")
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])

    # Build the config once so the quantizer's global config and the config
    # given to the Quantize stage cannot drift apart.
    quantization_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization, epsilon=2**-16
    )
    quantizer.set_global(quantization_config)

    return Quantize(quantizer, quantization_config)
166+
167+
168+
def test_lstm_16a8w_tosa_INT():
    """Run the LSTM model through the TOSA INT pipeline with 16A8W quantization.

    16A8W means 16-bit activations and 8-bit weights. The pipeline's default
    "quantize" stage arguments are swapped for the a16w8 LSTM quantizer
    before running.
    """
    model = TestLSTM.lstm
    example_inputs = TestLSTM.model_example_inputs

    pipeline = TosaPipelineINT[input_t](
        model,
        example_inputs,
        aten_op=[],
        exir_op=[],
        per_channel_quantization=False,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    a16w8_quantize_stage = get_symmetric_a16w8_lstm_quantizer()
    pipeline.change_args("quantize", a16w8_quantize_stage)
    pipeline.run()
183+
184+
185+
@pytest.mark.xfail(
    reason="MLETORCH-1452: AssertionError: Output 0 does not match reference output."
)
@common.XfailIfNoCorstone300
def test_lstm_16a8w_u55_INT():
    """Run the LSTM model through the Ethos-U55 INT pipeline with the a16w8 quantizer.

    Marked xfail: the output currently does not match the reference
    (tracked as MLETORCH-1452).
    """
    model = TestLSTM.lstm
    example_inputs = TestLSTM.model_example_inputs

    pipeline = EthosU55PipelineINT[input_t](
        model,
        example_inputs,
        aten_ops=[],
        exir_ops=[],
        use_to_edge_transform_and_lower=True,
    )

    # Replace the default "quantize" stage arguments with the a16w8 setup.
    a16w8_quantize_stage = get_symmetric_a16w8_lstm_quantizer()
    pipeline.change_args("quantize", a16w8_quantize_stage)
    pipeline.run()
200+
201+
202+
@common.XfailIfNoCorstone320
def test_lstm_16a8w_u85_INT():
    """Run the LSTM model through the Ethos-U85 INT pipeline with the a16w8 quantizer."""
    model = TestLSTM.lstm
    example_inputs = TestLSTM.model_example_inputs

    pipeline = EthosU85PipelineINT[input_t](
        model,
        example_inputs,
        aten_ops=[],
        exir_ops=[],
        use_to_edge_transform_and_lower=True,
    )

    # Replace the default "quantize" stage arguments with the a16w8 setup.
    a16w8_quantize_stage = get_symmetric_a16w8_lstm_quantizer()
    pipeline.change_args("quantize", a16w8_quantize_stage)
    pipeline.run()

0 commit comments

Comments
 (0)