Commit 5272b12

pdufour and xenova authored
Add support for --op_block_list in quantization script (huggingface#1036)
* Add support for op_block_list
* Remove arg
* Set default to none
* Minor code suggestions
* whoops - actually apply suggestions

Co-authored-by: Joshua Lochner <[email protected]>
1 parent c9f12e5 commit 5272b12

File tree: 2 files changed (+27, -2 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ __pycache__
 .vscode
 node_modules
 .cache
+.DS_STORE

 # Do not track build artifacts/generated files
 /dist

scripts/quantize.py

Lines changed: 26 additions & 2 deletions
@@ -1,7 +1,7 @@
 from enum import Enum

 from tqdm import tqdm
-from typing import Set
+from typing import Set, List, Optional
 import onnx
 import os

@@ -110,6 +110,16 @@ class QuantizationArguments:
         },
     )

+    op_block_list: List[str] = field(
+        default=None,
+        metadata={
+            "help": "List of operators to exclude from quantization."
+            "Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)"
+            "or your custom implemented operators.",
+            "nargs": "+",
+        },
+    )
+

 def get_operators(model: onnx.ModelProto) -> Set[str]:
     operators = set()
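The new argument declares nargs "+" in its metadata, so --op_block_list accepts one or more space-separated operator names on the command line. Below is a minimal stand-alone sketch of that parsing behavior, using plain argparse as a stand-in for the script's own dataclass-driven parser (the operator names in the example are placeholders):

import argparse

# nargs="+" makes --op_block_list consume one or more operator names;
# default=None means "nothing blocked" when the flag is omitted.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--op_block_list",
    nargs="+",
    default=None,
    help="List of operators to exclude from quantization.",
)

args = parser.parse_args(["--op_block_list", "Conv", "LayerNormalization"])
print(args.op_block_list)  # ['Conv', 'LayerNormalization']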
@@ -131,6 +141,7 @@ def quantize_q8(
     per_channel: bool,
     reduce_range: bool,
     weight_type: QuantType,
+    op_block_list: Optional[List[str]]
 ):
     """
     Quantize the weights of the model from float32 to int8/uint8
@@ -140,6 +151,10 @@ def quantize_q8(
     it is faster on most CPU architectures
     """

+    op_types_to_quantize = set(IntegerOpsRegistry.keys())
+    if op_block_list is not None:
+        op_types_to_quantize.difference_update(op_block_list)
+
     quantizer = ONNXQuantizer(
         model,
         per_channel,
@@ -151,7 +166,7 @@ def quantize_q8(
         tensors_range=None,
         nodes_to_quantize=[],
         nodes_to_exclude=[],
-        op_types_to_quantize=list(IntegerOpsRegistry.keys()),
+        op_types_to_quantize=op_types_to_quantize,
         extra_options=dict(
             EnableSubgraph=True,
             MatMulConstBOnly=True,
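In the 8-bit path the block list is subtractive: quantization starts from every op type in onnxruntime's IntegerOpsRegistry and removes whatever the user blocked. A small self-contained sketch of that set logic, with a placeholder set standing in for the real registry keys:

# Placeholder standing in for IntegerOpsRegistry.keys(); the real registry
# comes from onnxruntime's quantization tooling.
INTEGER_OPS_REGISTRY = {"Conv", "MatMul", "Attention", "Gather"}

def ops_to_quantize(op_block_list=None):
    # Start from everything the registry can quantize, then drop blocked ops.
    op_types = set(INTEGER_OPS_REGISTRY)
    if op_block_list is not None:
        op_types.difference_update(op_block_list)
    return op_types

print(sorted(ops_to_quantize()))          # ['Attention', 'Conv', 'Gather', 'MatMul']
print(sorted(ops_to_quantize(["Conv"])))  # ['Attention', 'Gather', 'MatMul']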
@@ -165,6 +180,7 @@ def quantize_q8(
 def quantize_fp16(
     model: onnx.ModelProto,
     save_path: str,
+    op_block_list: Optional[List[str]]
 ):
     """
     Quantize the weights of the model from float32 to float16
@@ -174,10 +190,15 @@ def quantize_fp16(
     # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2338583841
     disable_shape_infer = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF

+    blocked_ops = set(float16.DEFAULT_OP_BLOCK_LIST)
+    if op_block_list is not None:
+        blocked_ops.update(op_block_list)
+
     model_fp16 = float16.convert_float_to_float16(
         model,
         keep_io_types=True,
         disable_shape_infer=disable_shape_infer,
+        op_block_list=blocked_ops,
     )
     graph = gs.import_onnx(model_fp16)
     graph.toposort()
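In the float16 path the flag is additive instead: the user's operators are merged into the converter's default block list (float16.DEFAULT_OP_BLOCK_LIST), so the built-in exclusions are never lost. A sketch of that merge, with a placeholder standing in for the real default list:

# Placeholder for float16.DEFAULT_OP_BLOCK_LIST from onnxconverter-common;
# these are ops kept in float32 because fp16 is unsafe or unsupported for them.
DEFAULT_OP_BLOCK_LIST = {"Range", "CumSum", "Upsample"}

def fp16_block_list(op_block_list=None):
    # User-supplied ops are added on top of the defaults, never replacing them.
    blocked_ops = set(DEFAULT_OP_BLOCK_LIST)
    if op_block_list is not None:
        blocked_ops.update(op_block_list)
    return blocked_ops

print(sorted(fp16_block_list(["LayerNormalization"])))
# ['CumSum', 'LayerNormalization', 'Range', 'Upsample']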
@@ -271,6 +292,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
             quantize_fp16(
                 model,
                 save_path,
+                quantization_args.op_block_list
             )

         elif mode in (QuantMode.Q4, QuantMode.Q4F16):
@@ -287,6 +309,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
             quantize_fp16(
                 q4_model,
                 save_path,
+                quantization_args.op_block_list,
             )

         elif mode == QuantMode.BNB4:
@@ -331,6 +354,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
                 per_channel=quantization_args.per_channel,
                 reduce_range=quantization_args.reduce_range,
                 weight_type=weight_type,
+                op_block_list=quantization_args.op_block_list,
             )

