@@ -227,6 +227,40 @@ def annotate_single_in_single_out(
227
227
_annotated = True ,
228
228
)
229
229
230
def annotate_single_in_share_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate a single-input op whose output shares its input's qparams.

    The sole input (``node.args[0]``) is observed with the config's
    input-activation spec; the output is tagged with a
    ``SharedQuantizationSpec`` tied to that same input edge, so no separate
    output observer is inserted and input/output stay on one scale.

    Args:
        node: The FX node to annotate (e.g. ``aten.flatten.using_ints``).
        quantization_config: Supplies the input-activation quantization spec.
    """
    input_act = node.args[0]
    # Map producer node -> qspec for this consumer edge.
    input_qspec_map = {input_act: quantization_config.input_activation}

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        # Output reuses the (input_act, node) edge's quantization parameters.
        output_qspec=SharedQuantizationSpec((input_act, node)),
        _annotated=True,
    )
def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.stack`` so all inputs and the output share one qspec.

    The first stacked tensor is observed with the config's input-activation
    spec; every remaining distinct input and the output are tied to that
    first input edge via ``SharedQuantizationSpec``, keeping all stacked
    tensors (and the result) on a common scale/zero-point.

    Args:
        node: The ``aten.stack`` FX node; ``node.args[0]`` is the list of
            input nodes being stacked.
        quantization_config: Supplies the input-activation quantization spec.
    """
    input_nodes = node.args[0]

    first_input_node = input_nodes[0]
    input_qspec_map = {first_input_node: quantization_config.input_activation}
    share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
        (first_input_node, node)
    )

    # Skip duplicates: the same producer stacked twice must keep one entry,
    # and the first input must not be overwritten with the shared spec.
    for input_node in input_nodes[1:]:
        if input_node not in input_qspec_map:
            input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=share_qparams_with_input_act0_qspec,
        _annotated=True,
    )
def annotate_matmul_input1 (node : Node ):
231
265
quantization_config_8a8w = get_8a8w_qnn_ptq_config (
232
266
act_symmetric = True , act_observer = MinMaxObserver
@@ -247,6 +281,12 @@ def annotate_matmul_input1(node: Node):
247
281
]:
248
282
annotate_single_in_single_out (node , quantization_config_8a8w )
249
283
node = node .args [0 ]
284
+ elif node .target == torch .ops .aten .stack .default :
285
+ annotate_stack (node , quantization_config_8a8w )
286
+ node = node .args [0 ]
287
+ elif node .target == torch .ops .aten .flatten .using_ints :
288
+ annotate_single_in_share_out (node , quantization_config_8a8w )
289
+ node = node .args [0 ]
250
290
elif node .target == torch .ops .aten .cat .default :
251
291
annotate_cat (node , quantization_config_8a8w )
252
292
# For v, we tag 8a until conv op.
0 commit comments