Skip to content

Commit bed504e

Browse files
authored
Add custom annotation for new model export
Differential Revision: D77569950 Pull Request resolved: #12126
1 parent 5cc5421 commit bed504e

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,40 @@ def annotate_single_in_single_out(
227227
_annotated=True,
228228
)
229229

230+
def annotate_single_in_share_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate a single-input op whose output shares the input's qparams.

    The input activation edge is observed with the spec from
    ``quantization_config``; the output does not get its own observer but
    instead reuses the quantization parameters of the input edge through
    ``SharedQuantizationSpec``.
    """
    source_act = node.args[0]
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={source_act: quantization_config.input_activation},
        output_qspec=SharedQuantizationSpec((source_act, node)),
        _annotated=True,
    )
243+
244+
def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.stack`` so all inputs and the output share one qparam set.

    The first stacked tensor acts as the anchor: it is observed with the
    spec from ``quantization_config``, and every remaining input as well as
    the output is tied to that anchor edge via ``SharedQuantizationSpec``,
    so the whole stack ends up with identical scale/zero-point.
    """
    stacked_inputs = node.args[0]
    anchor = stacked_inputs[0]
    shared_with_anchor = SharedQuantizationSpec((anchor, node))

    qspec_map = {anchor: quantization_config.input_activation}
    for other in stacked_inputs[1:]:
        # A tensor that appears more than once keeps its first assigned spec.
        qspec_map.setdefault(other, shared_with_anchor)

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=qspec_map,
        output_qspec=shared_with_anchor,
        _annotated=True,
    )
263+
230264
def annotate_matmul_input1(node: Node):
231265
quantization_config_8a8w = get_8a8w_qnn_ptq_config(
232266
act_symmetric=True, act_observer=MinMaxObserver
@@ -247,6 +281,12 @@ def annotate_matmul_input1(node: Node):
247281
]:
248282
annotate_single_in_single_out(node, quantization_config_8a8w)
249283
node = node.args[0]
284+
elif node.target == torch.ops.aten.stack.default:
285+
annotate_stack(node, quantization_config_8a8w)
286+
node = node.args[0]
287+
elif node.target == torch.ops.aten.flatten.using_ints:
288+
annotate_single_in_share_out(node, quantization_config_8a8w)
289+
node = node.args[0]
250290
elif node.target == torch.ops.aten.cat.default:
251291
annotate_cat(node, quantization_config_8a8w)
252292
# For v, we tag 8a until conv op.

0 commit comments

Comments
 (0)