1
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
2
// Licensed under the MIT License.
3
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"
6
13
namespace onnxruntime {
7
14
namespace qnn {
15
+
// Inclusive bounds on the ONNX input count accepted by this builder
// (checked in ExplictOpCheck). ONNX TopK always carries exactly two
// inputs — the data tensor and the constant `k` — hence min == max == 2.
constexpr int TOPK_MIN_INPUT = 2;
constexpr int TOPK_MAX_INPUT = 2;
+
10
19
class TopKOpBuilder : public BaseOpBuilder {
11
20
public:
12
21
TopKOpBuilder () : BaseOpBuilder(" TopKOpBuilder" ) {}
@@ -41,8 +50,11 @@ class TopKOpBuilder : public BaseOpBuilder {
41
50
42
51
Status TopKOpBuilder::ExplictOpCheck (QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
43
52
size_t input_count = node_unit.Inputs ().size ();
53
+ size_t output_count = node_unit.Outputs ().size ();
44
54
ORT_RETURN_IF_NOT (input_count >= TOPK_MIN_INPUT && input_count <= TOPK_MAX_INPUT,
45
55
" For ONNX TopK operation the expected number of inputs is 2." );
56
+ ORT_RETURN_IF_NOT (output_count == 2 , " QNN TopK expects exactly 2 outputs." );
57
+
46
58
// Skip the first input. The second input needs to be an initializer.
47
59
const auto & input_1 = node_unit.Inputs ()[1 ].node_arg .Name ();
48
60
if (!qnn_model_wrapper.IsConstantInput (input_1)) {
@@ -57,14 +69,6 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N
57
69
if (0 == largest) {
58
70
return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " QNN TopK output is always largest values" );
59
71
}
60
- auto & input_0 = node_unit.Inputs ()[0 ];
61
- std::vector<uint32_t > input_shape;
62
- ORT_RETURN_IF_NOT (qnn_model_wrapper.GetOnnxShape (input_0.node_arg , input_shape), " Cannot get shape" );
63
- auto rank = input_shape.size ();
64
- auto axis = node_helper.Get (" axis" , -1 );
65
-
66
- ORT_RETURN_IF_NOT (axis == -1 || axis == static_cast <int32_t >(rank - 1 ),
67
- " QNN TopK's axis is always the last dimension" );
68
72
69
73
return Status::OK ();
70
74
}
@@ -81,6 +85,40 @@ Status TopKOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
81
85
const auto & inputs = node_unit.Inputs ();
82
86
ORT_RETURN_IF_ERROR (ProcessInput (qnn_model_wrapper, inputs[0 ], logger, input_names));
83
87
88
+ // HTP only supports TopK at the last axis, and thus check whether extra Transpose is required.
89
+ TensorInfo input_info = {};
90
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (node_unit.Inputs ()[0 ], input_info));
91
+
92
+ size_t input_rank = input_info.shape .size ();
93
+ int32_t axis = NodeAttrHelper (node_unit).Get (" axis" , -1 );
94
+ if (axis == -1 || axis == static_cast <int32_t >(input_rank - 1 )) {
95
+ return Status::OK ();
96
+ }
97
+
98
+ // Add Transpose to permute axis to the last.
99
+ std::string transpose_output_name = input_names[0 ] + " _ort_qnn_ep_transpose" ;
100
+ std::vector<uint32_t > transpose_perm;
101
+ ORT_RETURN_IF_ERROR (utils::GetPermToLastAxis (static_cast <uint32_t >(axis),
102
+ static_cast <uint32_t >(input_rank),
103
+ transpose_perm));
104
+
105
+ std::vector<uint32_t > transpose_output_shape = input_info.shape ;
106
+ transpose_output_shape[input_rank - 1 ] = input_info.shape [axis];
107
+ transpose_output_shape[axis] = input_info.shape [input_rank - 1 ];
108
+
109
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddTransposeNode (node_unit.Index (),
110
+ input_names[0 ],
111
+ transpose_output_name,
112
+ input_info.shape ,
113
+ transpose_perm,
114
+ transpose_output_shape,
115
+ input_info.qnn_data_type ,
116
+ input_info.quant_param ,
117
+ do_op_validation,
118
+ false ,
119
+ false ));
120
+ input_names[0 ] = transpose_output_name;
121
+
84
122
return Status::OK ();
85
123
}
86
124
@@ -108,9 +146,125 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
108
146
std::string k_param_name = k_param.GetParamTensorName ();
109
147
qnn_model_wrapper.AddParamWrapper (std::move (k_param));
110
148
std::vector<std::string> param_tensor_names{k_param_name};
111
- ORT_RETURN_IF_ERROR (ProcessOutputs (qnn_model_wrapper, node_unit, std::move (input_names),
112
- std::move (param_tensor_names), logger, do_op_validation,
113
- GetQnnOpType (node_unit.OpType ())));
149
+
150
+ // HTP only supports TopK at the last axis, and thus check whether extra Transpose is required.
151
+ TensorInfo input_info = {};
152
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (node_unit.Inputs ()[0 ], input_info));
153
+
154
+ size_t input_rank = input_info.shape .size ();
155
+ int32_t axis = NodeAttrHelper (node_unit).Get (" axis" , -1 );
156
+ if (axis == -1 || axis == static_cast <int32_t >(input_rank - 1 )) {
157
+ ORT_RETURN_IF_ERROR (ProcessOutputs (qnn_model_wrapper,
158
+ node_unit,
159
+ std::move (input_names),
160
+ std::move (param_tensor_names),
161
+ logger,
162
+ do_op_validation,
163
+ GetQnnOpType (node_unit.OpType ())));
164
+ return Status::OK ();
165
+ }
166
+
167
+ const auto & outputs = node_unit.Outputs ();
168
+ std::vector<std::string> transpose_input_names;
169
+ std::vector<std::vector<std::uint32_t >> transpose_input_shapes;
170
+
171
+ // Add TopK outputs.
172
+ for (size_t output_idx = 0 ; output_idx < 2 ; ++output_idx) {
173
+ const auto & output = outputs[output_idx];
174
+
175
+ // Since user may not be aware of the additional Transpose, the original output name of TopK node must be used by
176
+ // the additional Transpose node which has the same output as original TopK node.
177
+ const std::string& output_name = output.node_arg .Name ();
178
+ std::string transpose_input_name = output_name + " _ort_qnn_ep_transpose" ;
179
+ transpose_input_names.push_back (std::move (transpose_input_name));
180
+
181
+ // Since the input of TopK node is permuted, its output shape must be manually calculated.
182
+ TensorInfo output_info = {};
183
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (output, output_info));
184
+ size_t output_rank = output_info.shape .size ();
185
+
186
+ std::vector<uint32_t > transpose_input_shape = output_info.shape ;
187
+ transpose_input_shape[output_rank - 1 ] = output_info.shape [axis];
188
+ transpose_input_shape[axis] = output_info.shape [output_rank - 1 ];
189
+ transpose_input_shapes.push_back (std::move (transpose_input_shape));
190
+
191
+ QnnTensorWrapper output_tensorwrapper (transpose_input_names[output_idx],
192
+ QNN_TENSOR_TYPE_NATIVE,
193
+ output_info.qnn_data_type ,
194
+ output_info.quant_param .Copy (),
195
+ std::vector<uint32_t >(transpose_input_shapes[output_idx]));
196
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (output_tensorwrapper)), " Failed to add tensor." );
197
+ }
198
+
199
+ // Add TopK node.
200
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (utils::GetNodeName (node_unit),
201
+ QNN_OP_PACKAGE_NAME_QTI_AISW,
202
+ GetQnnOpType (node_unit.OpType ()),
203
+ std::move (input_names),
204
+ std::vector<std::string>(transpose_input_names),
205
+ std::move (param_tensor_names)),
206
+ " Failed to add node." );
207
+
208
+ // Add Transpose nodes for each output to permute back.
209
+ for (size_t output_idx = 0 ; output_idx < 2 ; ++output_idx) {
210
+ const auto & output = outputs[output_idx];
211
+ const std::string& output_name = output.node_arg .Name ();
212
+
213
+ TensorInfo output_info = {};
214
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (output, output_info));
215
+ size_t output_rank = output_info.shape .size ();
216
+
217
+ std::vector<uint32_t > transpose_perm;
218
+ ORT_RETURN_IF_ERROR (utils::GetPermToLastAxis (static_cast <uint32_t >(axis),
219
+ static_cast <uint32_t >(output_rank),
220
+ transpose_perm));
221
+
222
+ std::string transpose_output_name = output_name;
223
+ bool is_graph_output = qnn_model_wrapper.IsGraphOutput (output_name);
224
+
225
+ // TopK's second output is indices which could be INT64 dtype, and QnnTensorWrapper directly changes the dtype to
226
+ // INT32 during the wrapper construction. Nevertheless, if this output happens to be graph output, an additional
227
+ // Cast must be added to cast dtype from INT32 back to INT64.
228
+ bool is_cast_required = output_idx == 1 && output_info.qnn_data_type == QNN_DATATYPE_INT_64 && is_graph_output;
229
+ std::string cast_input_name = " " ;
230
+ if (is_cast_required) {
231
+ cast_input_name = transpose_output_name + " _ort_qnn_ep_cast" ;
232
+ // For the same reason described above, the original output name is now used by this Cast.
233
+ transpose_output_name = cast_input_name;
234
+ // Since additional Cast is added, below Transpose is no longer graph output.
235
+ is_graph_output = false ;
236
+ }
237
+
238
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddTransposeNode (node_unit.Index (),
239
+ transpose_input_names[output_idx],
240
+ transpose_output_name,
241
+ transpose_input_shapes[output_idx],
242
+ transpose_perm,
243
+ output_info.shape ,
244
+ output_info.qnn_data_type ,
245
+ output_info.quant_param ,
246
+ do_op_validation,
247
+ false ,
248
+ is_graph_output));
249
+
250
+ if (is_cast_required) {
251
+ QnnTensorWrapper cast_output_tensorwrapper (output_name,
252
+ QNN_TENSOR_TYPE_APP_READ,
253
+ output_info.qnn_data_type ,
254
+ output_info.quant_param .Copy (),
255
+ std::vector<uint32_t >(output_info.shape ));
256
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (cast_output_tensorwrapper)),
257
+ " Failed to add tensor." );
258
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (cast_input_name,
259
+ QNN_OP_PACKAGE_NAME_QTI_AISW,
260
+ " Cast" ,
261
+ {cast_input_name},
262
+ {output_name},
263
+ {}),
264
+ " Failed to add node" );
265
+ }
266
+ }
267
+
114
268
return Status::OK ();
115
269
}
116
270
0 commit comments