@@ -1,7 +1,7 @@
from enum import Enum

from tqdm import tqdm
-from typing import Set
+from typing import Set, List, Optional
import onnx
import os

@@ -110,6 +110,16 @@ class QuantizationArguments:
        },
    )

+    op_block_list: List[str] = field(
+        default=None,
+        metadata={
+            "help": "List of operators to exclude from quantization. "
+            "Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/) "
+            "or your custom implemented operators.",
+            "nargs": "+",
+        },
+    )
+


def get_operators(model: onnx.ModelProto) -> Set[str]:
    operators = set()
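
If the script feeds QuantizationArguments to transformers.HfArgumentParser (an assumption; only the dataclass fields are visible in this diff), the `"nargs": "+"` metadata is forwarded to argparse, so the new flag accepts one or more operator names. A minimal parsing sketch; the import path and operator names are placeholders:

# Sketch only: assumes QuantizationArguments is importable from this script
# ("quantize" is a placeholder module name) and that its other fields all have defaults.
from transformers import HfArgumentParser
from quantize import QuantizationArguments  # hypothetical import path

parser = HfArgumentParser(QuantizationArguments)
(quantization_args,) = parser.parse_args_into_dataclasses(
    ["--op_block_list", "LayerNormalization", "Gelu"]  # example operator names
)
print(quantization_args.op_block_list)  # ['LayerNormalization', 'Gelu']
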
@@ -131,6 +141,7 @@ def quantize_q8(
    per_channel: bool,
    reduce_range: bool,
    weight_type: QuantType,
+    op_block_list: Optional[List[str]]
):
    """
    Quantize the weights of the model from float32 to int8/uint8
@@ -140,6 +151,10 @@ def quantize_q8(
    it is faster on most CPU architectures
    """

+    op_types_to_quantize = set(IntegerOpsRegistry.keys())
+    if op_block_list is not None:
+        op_types_to_quantize.difference_update(op_block_list)
+
    quantizer = ONNXQuantizer(
        model,
        per_channel,
@@ -151,7 +166,7 @@ def quantize_q8(
        tensors_range=None,
        nodes_to_quantize=[],
        nodes_to_exclude=[],
-        op_types_to_quantize=list(IntegerOpsRegistry.keys()),
+        op_types_to_quantize=op_types_to_quantize,
        extra_options=dict(
            EnableSubgraph=True,
            MatMulConstBOnly=True,
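
For context, the set that op_block_list is subtracted from is onnxruntime's registry of dynamically quantizable operator types. A quick way to see which op types can be blocked (the import path is assumed from the diff's use of IntegerOpsRegistry; the script imports it elsewhere):

# Assumption: IntegerOpsRegistry comes from onnxruntime's quantization registry,
# as the usage above suggests.
from onnxruntime.quantization.registry import IntegerOpsRegistry

print(sorted(IntegerOpsRegistry.keys()))
# Passing any of these names via op_block_list keeps that operator type in float32
# instead of being dynamically quantized to int8/uint8.
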
@@ -165,6 +180,7 @@ def quantize_q8(
def quantize_fp16(
    model: onnx.ModelProto,
    save_path: str,
+    op_block_list: Optional[List[str]]
):
    """
    Quantize the weights of the model from float32 to float16
@@ -174,10 +190,15 @@ def quantize_fp16(
    # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2338583841
    disable_shape_infer = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF

+    blocked_ops = set(float16.DEFAULT_OP_BLOCK_LIST)
+    if op_block_list is not None:
+        blocked_ops.update(op_block_list)
+
    model_fp16 = float16.convert_float_to_float16(
        model,
        keep_io_types=True,
        disable_shape_infer=disable_shape_infer,
+        op_block_list=blocked_ops,
    )
    graph = gs.import_onnx(model_fp16)
    graph.toposort()
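
A minimal, self-contained sketch of what the merged block list does, assuming the `float16` module used above is onnxconverter_common.float16 (as DEFAULT_OP_BLOCK_LIST and convert_float_to_float16 suggest): blocked node types keep float32 tensors, with Cast nodes inserted at the float16/float32 boundary, while the rest of the graph is converted to float16.

# Sketch under the assumptions stated above; the tiny model and blocked op are examples only.
import onnx
from onnx import TensorProto, helper
from onnxconverter_common import float16

# Tiny graph: MatMul -> Softmax
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4])
W = helper.make_tensor("W", TensorProto.FLOAT, [4, 4], [0.1] * 16)
graph = helper.make_graph(
    [
        helper.make_node("MatMul", ["X", "W"], ["H"]),
        helper.make_node("Softmax", ["H"], ["Y"]),
    ],
    "demo", [X], [Y], initializer=[W],
)
model = helper.make_model(graph)

# Mirror the merging done in quantize_fp16: defaults plus a user-supplied block list.
blocked_ops = set(float16.DEFAULT_OP_BLOCK_LIST)
blocked_ops.update(["Softmax"])  # example: keep Softmax in float32

model_fp16 = float16.convert_float_to_float16(
    model,
    keep_io_types=True,
    op_block_list=blocked_ops,
)
print([n.op_type for n in model_fp16.graph.node])  # Cast nodes appear at the fp16/fp32 boundary
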
@@ -271,6 +292,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
        quantize_fp16(
            model,
            save_path,
+            quantization_args.op_block_list
        )

    elif mode in (QuantMode.Q4, QuantMode.Q4F16):
@@ -287,6 +309,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
            quantize_fp16(
                q4_model,
                save_path,
+                quantization_args.op_block_list,
            )

    elif mode == QuantMode.BNB4:
@@ -331,6 +354,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
            per_channel=quantization_args.per_channel,
            reduce_range=quantization_args.reduce_range,
            weight_type=weight_type,
+            op_block_list=quantization_args.op_block_list,
        )
