# AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
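The error in the title is what `torch.ops` raises when an op is looked up before it has been registered; `quantize_per_channel_group` is registered as a side effect of importing `torch.ao.quantization.fx._decomposed`, which is why the `quantized_decomposed_lib` import below must stay even though it looks unused. A minimal sanity check, assuming a PyTorch build that defines this op:

```python
import torch

# Importing this module registers the quantized_decomposed ops (including
# quantize_per_channel_group) on torch.ops as a side effect; the binding
# itself is unused, hence the noqa on the real import below.
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

# With the registration done, the attribute lookup no longer raises.
print(hasattr(torch.ops.quantized_decomposed, "quantize_per_channel_group"))  # True
```

The hunks below migrate the lowbit quantizers to the current torchao config-based API: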
```diff
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
+from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
+from torchao.experimental.quant_api import EmbeddingQuantizer
+from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
     int4_weight_only,
     Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightQuantizer,
+    Int8DynamicActivationIntxWeightConfig,
+    MappingType,
     quantize_,
 )
 from torchao.utils import unwrap_tensor_subclass
```

```diff
     state_dict_device,
     use_et_backend,
 )
-from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
-    PackedLinearInt8DynamicActivationIntxWeightLayout,
-)
-from torchao.experimental.quant_api import (
-    int8_dynamic_activation_intx_weight,
-    IntxWeightEmbeddingQuantizer,
-)
-from torchao.quantization.granularity import (
-    PerGroup,
-    PerRow,
-)
-from torchao.dtypes import PlainLayout


 # Flag for whether the a8wxdq quantizer is available.
```
```diff
@@ -87,7 +80,7 @@ def get_named_parameters(func: Callable) -> List[str]:
     return named_params

 def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
-    for key in q_kwargs.keys():
+    for key in list(q_kwargs.keys()):
         if key not in named_params:
             print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
             del q_kwargs[key]
```
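The `list(...)` copy matters because `validate_args` deletes entries while it iterates; looping over `dict.keys()` directly while mutating the dict raises `RuntimeError: dictionary changed size during iteration` on Python 3. A small standalone illustration of the fixed loop:

```python
# Snapshot the keys before deleting, so mutation during the loop is safe.
named_params = ["groupsize", "bitwidth"]
q_kwargs = {"groupsize": 32, "bitwidth": 4, "stray_option": True}

for key in list(q_kwargs.keys()):  # copy of the keys; the dict may shrink below
    if key not in named_params:
        print(f"Specification has extraneous key {key}. Ignoring.")
        del q_kwargs[key]

print(q_kwargs)  # {'groupsize': 32, 'bitwidth': 4}
```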
```diff
@@ -137,29 +130,34 @@ def quantize_model(
         group_size = q_kwargs["groupsize"]
         bit_width = q_kwargs["bitwidth"]
         has_weight_zeros = q_kwargs["has_weight_zeros"]
-        granularity = PerRow() if group_size == -1 else PerGroup(group_size)
+        granularity = PerAxis() if group_size == -1 else PerGroup(group_size)
         weight_dtype = getattr(torch, f"int{bit_width}")
+        weight_mapping_type = (
+            MappingType.ASYMMETRIC
+            if has_weight_zeros
+            else MappingType.SYMMETRIC
+        )

         try:
             quantize_(
-                model,
-                int8_dynamic_activation_intx_weight(
+                model,
+                Int8DynamicActivationIntxWeightConfig(
                     weight_dtype=weight_dtype,
-                    granularity=granularity,
-                    has_weight_zeros=has_weight_zeros,
+                    weight_granularity=granularity,
+                    weight_mapping_type=weight_mapping_type,
                     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
         except Exception as e:
             print(f"Encountered error during quantization: {e}")
-            print("Trying with PlainLayout")
+            print("Trying with QDQLayout")
             quantize_(
-                model,
-                int8_dynamic_activation_intx_weight(
+                model,
+                Int8DynamicActivationIntxWeightConfig(
                     weight_dtype=weight_dtype,
-                    granularity=granularity,
-                    has_weight_zeros=has_weight_zeros,
-                    layout=PlainLayout(),
+                    weight_granularity=granularity,
+                    weight_mapping_type=weight_mapping_type,
+                    layout=QDQLayout(),
                 ),
             )
```
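For reference, a standalone sketch of the config-based call this hunk migrates to, using only the argument names shown above. It assumes a torchao build that exposes `Int8DynamicActivationIntxWeightConfig` and `QDQLayout`; the toy module and the concrete groupsize/bitwidth values are illustrative, not part of the change:

```python
import torch
import torch.nn as nn
from torchao.dtypes import QDQLayout
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)

# Toy module standing in for the model being quantized.
model = nn.Sequential(nn.Linear(256, 256))

# Mirrors q_kwargs of groupsize=32, bitwidth=4, has_weight_zeros=True,
# taking the QDQLayout fallback path from the hunk above.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
        weight_mapping_type=MappingType.ASYMMETRIC,
        layout=QDQLayout(),
    ),
)
```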
```diff
@@ -174,6 +172,22 @@ def quantize_model(
         print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
         set_precision(torch.float32)

+        group_size = q_kwargs["groupsize"]
+        bit_width = q_kwargs["bitwidth"]
+        has_weight_zeros = q_kwargs.get("has_weight_zeros", True)
+        q_kwargs["granularity"] = (
+            PerAxis() if group_size == -1 else PerGroup(group_size)
+        )
+        q_kwargs["weight_dtype"] = getattr(torch, f"int{bit_width}")
+        q_kwargs["mapping_type"] = (
+            MappingType.ASYMMETRIC
+            if has_weight_zeros
+            else MappingType.SYMMETRIC
+        )
+        q_kwargs["use_fallback"] = False
+        del q_kwargs["groupsize"]
+        del q_kwargs["bitwidth"]
+
     if quantizer == "linear:afpwx" and device != "mps":
         raise RuntimeError("linear:afpwx quantization can only run on mps device!")
```
```diff
@@ -188,7 +202,10 @@ def quantize_model(
     # Handle tokenizer for scenarios where the quantizer needs to tokenize sample inputs
     if "tokenizer" in named_params:
         q_kwargs["tokenizer"] = tokenizer
-    quant_handler = q(device=device, precision=precision, **q_kwargs)
+    if quantizer == "embedding:wx":
+        quant_handler = q(**q_kwargs)
+    else:
+        quant_handler = q(device=device, precision=precision, **q_kwargs)

     # quantize model
     model = quant_handler.quantize(model)
```
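With the `embedding:wx` handler now constructed without `device`/`precision`, the remapped `q_kwargs` from the earlier hunk are passed straight through to `EmbeddingQuantizer`. A hypothetical standalone equivalent, assuming the keyword names built above match the constructor; the toy embedding module is illustrative only:

```python
import torch
import torch.nn as nn
from torchao.experimental.quant_api import EmbeddingQuantizer
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import MappingType

# Toy module standing in for a model with embedding tables.
model = nn.Sequential(nn.Embedding(1024, 64))

# Equivalent of q_kwargs after the remapping above
# (groupsize=32, bitwidth=4, has_weight_zeros=True, use_fallback=False).
quant_handler = EmbeddingQuantizer(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    mapping_type=MappingType.ASYMMETRIC,
    use_fallback=False,
)
model = quant_handler.quantize(model)
```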
```diff
@@ -939,7 +956,7 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "embedding:wx": IntxWeightEmbeddingQuantizer,
+    "embedding:wx": EmbeddingQuantizer,
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
```
```diff
@@ -979,5 +996,19 @@ def quantized_model(self) -> nn.Module:
     except Exception as e:
         print("Unable to load torchao mps ops library.")

+    torchao_experimental_mps_op_lib_spec = importlib.util.spec_from_file_location(
+        "torchao_experimental_mps_op_lib",
+        f"{torchao_build_path}/src/ao/torchao/experimental/ops/mps/mps_op_lib.py",
+    )
+    torchao_experimental_mps_op_lib = importlib.util.module_from_spec(
+        torchao_experimental_mps_op_lib_spec
+    )
+    sys.modules["torchao_experimental_mps_op_lib"] = torchao_experimental_mps_op_lib
+    torchao_experimental_mps_op_lib_spec.loader.exec_module(
+        torchao_experimental_mps_op_lib
+    )
+    from torchao_experimental_mps_op_lib import *
+
+
 except Exception as e:
     print("Unable to import torchao experimental quant_api with error: ", e)
```