Skip to content

Commit 052291d

Browse files
Add linear as a supported op for activation quantization. (#2577)
1 parent 6398eef commit 052291d

File tree

3 files changed

+69
-3
lines changed

3 files changed

+69
-3
lines changed

coremltools/converters/mil/mil/passes/defs/optimize_activation_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ def _try_match_and_transform_pattern(
159159
- (`quantize` ->) `dequantize` -> `conv` -> `relu` -> `quantize` -> `dequantize`
160160
"""
161161

162-
# Reject if 1st operation is not `conv`/`add`/`pool`.
163-
SUPPORTED_OP_TYPES = ["conv", "add", "avg_pool", "max_pool"]
162+
# Reject if 1st operation is not `conv`/`add`/`pool`/`linear`.
163+
SUPPORTED_OP_TYPES = ["conv", "add", "avg_pool", "max_pool", "linear"]
164164
if any(_check_child_op_type(dequantize_op, val) for val in SUPPORTED_OP_TYPES):
165165
pass
166166
else:

coremltools/optimize/coreml/_quantization_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1640,7 +1640,7 @@ class insert_prefix_quantize_dequantize_pair(AbstractActCompressionPass):
16401640

16411641
_SUPPORTED_CONFIG_TYPE = OpLinearQuantizerConfig
16421642

1643-
SUPPORTED_UNARY_OP_TYPES = ["conv", "avg_pool", "max_pool"]
1643+
SUPPORTED_UNARY_OP_TYPES = ["conv", "avg_pool", "max_pool", "linear"]
16441644
SUPPORTED_BINARY_OP_TYPES = ["add"]
16451645
SUPPORTED_OP_TYPES = SUPPORTED_UNARY_OP_TYPES + SUPPORTED_BINARY_OP_TYPES
16461646

coremltools/test/optimize/coreml/test_passes.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,21 @@ def prog(x):
900900

901901
return prog
902902

903+
@staticmethod
def _get_test_program_linear():
    """Build an iOS17 MIL program containing a single ``linear`` op.

    The program casts its fp32 input to fp16, applies a linear layer with
    a random fp32 weight, and casts the result back to fp32.
    """

    @mb.program(
        input_specs=[mb.TensorSpec(shape=(30, 10))], opset_version=ct.target.iOS17
    )
    def prog(x):
        # Weight is (30, 10): out_features x in_features for mb.linear.
        weight = np.random.rand(30, 10).astype(np.float32)
        casted = mb.cast(x=x, dtype="fp16")
        out = mb.linear(x=casted, weight=weight)
        return mb.cast(x=out, dtype="fp32")

    return prog
917+
903918
@staticmethod
904919
def _get_test_mlmodel_conv_concat():
905920
"""A mlmodel has a concat with 2 inputs and 1 output all surrounded by conv."""
@@ -3710,6 +3725,57 @@ def test_global_config_activation_quantizer_on_pattern_3(self, mode, dtype, weig
37103725
"cast",
37113726
]
37123727

3728+
@pytest.mark.parametrize(
    "mode, dtype, weight_threshold",
    itertools.product(
        ["LINEAR", "LINEAR_SYMMETRIC"],
        [np.int8, np.uint8, types.int8, types.uint8],
        [1000],
    ),
)
def test_global_config_activation_quantizer_on_pattern_4(self, mode, dtype, weight_threshold):
    """
    Global config would compress all operations with the same config
    Valid patterns:
    - linear
    """
    op_config = cto.coreml.OpLinearQuantizerConfig(
        mode=mode, dtype=dtype, weight_threshold=weight_threshold
    )
    config = cto.coreml.OptimizationConfig(global_config=op_config)

    # Test case: linear
    prog = self._get_test_program_linear()

    # Create activation_stats for all intermediate tensors.
    activation_stats = gen_activation_stats_for_program(prog)

    # Insert prefix quantize/dequantize pairs.
    graph_pass_1 = _insert_prefix_quantize_dequantize_pair(config)
    graph_pass_1.set_options([PassOption("activation_stats", activation_stats)])

    # Insert suffix quantize/dequantize pairs.
    graph_pass_2 = PASS_REGISTRY["compression::insert_suffix_quantize_dequantize_pair"]
    graph_pass_2.set_options(
        [PassOption("config", config), PassOption("activation_stats", activation_stats)]
    )

    apply_pass_and_basic_check(prog, graph_pass_1)
    apply_pass_and_basic_check(prog, graph_pass_2)

    # The linear op must end up framed by quantize/dequantize on both sides.
    assert get_op_types_in_program(prog) == [
        "cast",
        "quantize",
        "dequantize",
        "linear",
        "quantize",
        "dequantize",
        "cast",
    ]
3778+
37133779

37143780
class TestGetActivationStats(TestCompressionPasses):
37153781
def test_get_activation_calibration_stats_basic(self):

0 commit comments

Comments
 (0)