@@ -900,6 +900,21 @@ def prog(x):
900900
901901 return prog
902902
903+ @staticmethod
904+ def _get_test_program_linear ():
905+ """An iOS17 program with linear."""
906+ @mb .program (
907+ input_specs = [mb .TensorSpec (shape = (30 , 10 ))], opset_version = ct .target .iOS17
908+ )
909+ def prog (x ):
910+ linear_weight = np .random .rand (30 , 10 ).astype (np .float32 )
911+ x = mb .cast (x = x , dtype = "fp16" )
912+ x = mb .linear (x = x , weight = linear_weight )
913+ x = mb .cast (x = x , dtype = "fp32" )
914+ return x
915+
916+ return prog
917+
903918 @staticmethod
904919 def _get_test_mlmodel_conv_concat ():
905920 """A mlmodel has a concat with 2 inputs and 1 output all surrounded by conv."""
@@ -3710,6 +3725,57 @@ def test_global_config_activation_quantizer_on_pattern_3(self, mode, dtype, weig
37103725 "cast" ,
37113726 ]
37123727
3728+ @pytest .mark .parametrize (
3729+ "mode, dtype, weight_threshold" ,
3730+ itertools .product (
3731+ ["LINEAR" , "LINEAR_SYMMETRIC" ],
3732+ [np .int8 , np .uint8 , types .int8 , types .uint8 ],
3733+ [1000 ],
3734+ ),
3735+ )
3736+ def test_global_config_activation_quantizer_on_pattern_4 (self , mode , dtype , weight_threshold ):
3737+ """
3738+ Global config would compress all operations with the same config
3739+ Valid patterns:
3740+ - linear
3741+ """
3742+
3743+ # Insert prefix quantize/dequantize pairs
3744+ op_config = cto .coreml .OpLinearQuantizerConfig (
3745+ mode = mode , dtype = dtype , weight_threshold = weight_threshold
3746+ )
3747+ config = cto .coreml .OptimizationConfig (global_config = op_config )
3748+
3749+ # Test case: conv
3750+ prog = self ._get_test_program_linear ()
3751+
3752+ # Create activation_stats to all intermediate tensors
3753+ activation_stats = gen_activation_stats_for_program (prog )
3754+
3755+ # Insert prefix quantize/dequantize pairs
3756+ graph_pass_1 = _insert_prefix_quantize_dequantize_pair (config )
3757+ graph_pass_1 .set_options ([PassOption ("activation_stats" , activation_stats )])
3758+
3759+ # Insert suffix quantize/dequantize pairs
3760+ graph_pass_2 = PASS_REGISTRY ["compression::insert_suffix_quantize_dequantize_pair" ]
3761+ graph_pass_2 .set_options (
3762+ [PassOption ("config" , config ), PassOption ("activation_stats" , activation_stats )]
3763+ )
3764+
3765+ apply_pass_and_basic_check (prog , graph_pass_1 )
3766+ apply_pass_and_basic_check (prog , graph_pass_2 )
3767+
3768+ print (get_op_types_in_program (prog ))
3769+ assert get_op_types_in_program (prog ) == [
3770+ "cast" ,
3771+ "quantize" ,
3772+ "dequantize" ,
3773+ "linear" ,
3774+ "quantize" ,
3775+ "dequantize" ,
3776+ "cast" ,
3777+ ]
3778+
37133779
37143780class TestGetActivationStats (TestCompressionPasses ):
37153781 def test_get_activation_calibration_stats_basic (self ):
0 commit comments