
Commit f390eb5

[QNN-EP] Support non-last axis TopK. (microsoft#24881)
### Description

In the TopK op builder, add Transpose nodes around TopK to permute the target axis to the last position before the op and permute it back afterward. Additionally, since TopK's second output (indices) may have INT64 dtype, add a Cast to convert the transformed INT32 back to INT64 when that output is a graph output.

### Motivation and Context

QNN only accepts TopK on the last axis, while ONNX/ORT's TopK takes an arbitrary axis attribute. Extend the TopK op builder so that TopK on a non-last axis no longer falls back to CPU.
1 parent aa64037 commit f390eb5

6 files changed: +230 −44 lines changed
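For context, a minimal sketch of the graph rewrite described above, for TopK with axis != rank - 1 (the standalone helper below is illustrative, not the EP code itself):

// input -> Transpose(perm) -> TopK(last axis) -> Transpose(perm) -> values
//                                           \-> Transpose(perm) -> [Cast INT32->INT64] -> indices
//
// The permutation only swaps `axis` with the last dimension, so it is its
// own inverse and the same perm is reused on both sides of the op.
#include <cstdint>
#include <utility>
#include <vector>

std::vector<uint32_t> PermSwappingAxisWithLast(uint32_t axis, uint32_t rank) {
  std::vector<uint32_t> perm(rank);
  for (uint32_t dim = 0; dim < rank; ++dim) perm[dim] = dim;  // identity
  std::swap(perm[axis], perm[rank - 1]);  // e.g. rank 4, axis 1 -> {0, 3, 2, 1}
  return perm;
}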

onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc

Lines changed: 8 additions & 20 deletions
@@ -56,22 +56,6 @@ std::vector<uint32_t> FlattenShapeFromAxis(const std::vector<uint32_t>& input_sh
   return output_shape;
 }
 
-std::vector<uint32_t> GetTransposePermToUseLastAxis(uint32_t input_rank, uint32_t axis) {
-  assert(axis < input_rank);
-  std::vector<uint32_t> transpose_perm;
-  transpose_perm.reserve(input_rank);
-
-  for (uint32_t dim = 0; dim < input_rank; dim++) {
-    transpose_perm.push_back(dim);
-  }
-
-  // Swap axis dim with last dim.
-  transpose_perm[axis] = input_rank - 1;
-  transpose_perm[input_rank - 1] = axis;
-
-  return transpose_perm;
-}
-
 Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                        const NodeUnit& node_unit,
                                        const logging::Logger& logger,
@@ -131,8 +115,10 @@ Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
   QNN EP is able to support arbitrary axis attribute by wrapping transposes around the operator.
   */
   std::string transpose_output_name = input_name + "_ort_qnn_ep_transpose";
-  std::vector<uint32_t> transpose_perm = GetTransposePermToUseLastAxis(static_cast<uint32_t>(input_rank),
-                                                                       static_cast<uint32_t>(axis));
+  std::vector<uint32_t> transpose_perm;
+  ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast<uint32_t>(axis),
+                                               static_cast<uint32_t>(input_rank),
+                                               transpose_perm));
 
   std::vector<uint32_t> transpose_output_shape = input_info.shape;
   transpose_output_shape[input_rank - 1] = input_info.shape[axis];
@@ -243,8 +229,10 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_
                     "Failed to add node.");
 
   const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(orig_output_name);
-  std::vector<uint32_t> transpose_perm = GetTransposePermToUseLastAxis(static_cast<uint32_t>(output_rank),
-                                                                       static_cast<uint32_t>(axis));
+  std::vector<uint32_t> transpose_perm;
+  ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast<uint32_t>(axis),
+                                               static_cast<uint32_t>(output_rank),
+                                               transpose_perm));
 
   ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
                                                          transpose_input_name,

onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc

Lines changed: 165 additions & 11 deletions
@@ -1,12 +1,21 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
 #include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
 #include "core/providers/qnn/builder/op_builder_factory.h"
 #include "core/providers/qnn/builder/qnn_utils.h"
+
 namespace onnxruntime {
 namespace qnn {
+
 const int TOPK_MIN_INPUT = 2;
 const int TOPK_MAX_INPUT = 2;
+
 class TopKOpBuilder : public BaseOpBuilder {
  public:
   TopKOpBuilder() : BaseOpBuilder("TopKOpBuilder") {}
@@ -41,8 +50,11 @@ class TopKOpBuilder : public BaseOpBuilder {
 
 Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
   size_t input_count = node_unit.Inputs().size();
+  size_t output_count = node_unit.Outputs().size();
   ORT_RETURN_IF_NOT(input_count >= TOPK_MIN_INPUT && input_count <= TOPK_MAX_INPUT,
                     "For ONNX TopK operation the expected number of inputs is 2.");
+  ORT_RETURN_IF_NOT(output_count == 2, "QNN TopK expects exactly 2 outputs.");
+
   // Skip the first input. The second input needs to be an initializer.
   const auto& input_1 = node_unit.Inputs()[1].node_arg.Name();
   if (!qnn_model_wrapper.IsConstantInput(input_1)) {
@@ -57,14 +69,6 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N
   if (0 == largest) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN TopK output is always largest values");
   }
-  auto& input_0 = node_unit.Inputs()[0];
-  std::vector<uint32_t> input_shape;
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), "Cannot get shape");
-  auto rank = input_shape.size();
-  auto axis = node_helper.Get("axis", -1);
-
-  ORT_RETURN_IF_NOT(axis == -1 || axis == static_cast<int32_t>(rank - 1),
-                    "QNN TopK's axis is always the last dimension");
 
   return Status::OK();
 }
@@ -81,6 +85,40 @@ Status TopKOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
   const auto& inputs = node_unit.Inputs();
   ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));
 
+  // HTP only supports TopK at the last axis, and thus check whether extra Transpose is required.
+  TensorInfo input_info = {};
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));
+
+  size_t input_rank = input_info.shape.size();
+  int32_t axis = NodeAttrHelper(node_unit).Get("axis", -1);
+  if (axis == -1 || axis == static_cast<int32_t>(input_rank - 1)) {
+    return Status::OK();
+  }
+
+  // Add Transpose to permute axis to the last.
+  std::string transpose_output_name = input_names[0] + "_ort_qnn_ep_transpose";
+  std::vector<uint32_t> transpose_perm;
+  ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast<uint32_t>(axis),
+                                               static_cast<uint32_t>(input_rank),
+                                               transpose_perm));
+
+  std::vector<uint32_t> transpose_output_shape = input_info.shape;
+  transpose_output_shape[input_rank - 1] = input_info.shape[axis];
+  transpose_output_shape[axis] = input_info.shape[input_rank - 1];
+
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
+                                                         input_names[0],
+                                                         transpose_output_name,
+                                                         input_info.shape,
+                                                         transpose_perm,
+                                                         transpose_output_shape,
+                                                         input_info.qnn_data_type,
+                                                         input_info.quant_param,
+                                                         do_op_validation,
+                                                         false,
+                                                         false));
+  input_names[0] = transpose_output_name;
+
   return Status::OK();
 }
 
@@ -108,9 +146,125 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
   std::string k_param_name = k_param.GetParamTensorName();
   qnn_model_wrapper.AddParamWrapper(std::move(k_param));
   std::vector<std::string> param_tensor_names{k_param_name};
-  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names),
-                                     std::move(param_tensor_names), logger, do_op_validation,
-                                     GetQnnOpType(node_unit.OpType())));
+
+  // HTP only supports TopK at the last axis, and thus check whether extra Transpose is required.
+  TensorInfo input_info = {};
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));
+
+  size_t input_rank = input_info.shape.size();
+  int32_t axis = NodeAttrHelper(node_unit).Get("axis", -1);
+  if (axis == -1 || axis == static_cast<int32_t>(input_rank - 1)) {
+    ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper,
+                                       node_unit,
+                                       std::move(input_names),
+                                       std::move(param_tensor_names),
+                                       logger,
+                                       do_op_validation,
+                                       GetQnnOpType(node_unit.OpType())));
+    return Status::OK();
+  }
+
+  const auto& outputs = node_unit.Outputs();
+  std::vector<std::string> transpose_input_names;
+  std::vector<std::vector<std::uint32_t>> transpose_input_shapes;
+
+  // Add TopK outputs.
+  for (size_t output_idx = 0; output_idx < 2; ++output_idx) {
+    const auto& output = outputs[output_idx];
+
+    // Since user may not be aware of the additional Transpose, the original output name of TopK node must be used by
+    // the additional Transpose node which has the same output as original TopK node.
+    const std::string& output_name = output.node_arg.Name();
+    std::string transpose_input_name = output_name + "_ort_qnn_ep_transpose";
+    transpose_input_names.push_back(std::move(transpose_input_name));
+
+    // Since the input of TopK node is permuted, its output shape must be manually calculated.
+    TensorInfo output_info = {};
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, output_info));
+    size_t output_rank = output_info.shape.size();
+
+    std::vector<uint32_t> transpose_input_shape = output_info.shape;
+    transpose_input_shape[output_rank - 1] = output_info.shape[axis];
+    transpose_input_shape[axis] = output_info.shape[output_rank - 1];
+    transpose_input_shapes.push_back(std::move(transpose_input_shape));
+
+    QnnTensorWrapper output_tensorwrapper(transpose_input_names[output_idx],
+                                          QNN_TENSOR_TYPE_NATIVE,
+                                          output_info.qnn_data_type,
+                                          output_info.quant_param.Copy(),
+                                          std::vector<uint32_t>(transpose_input_shapes[output_idx]));
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
+  }
+
+  // Add TopK node.
+  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
+                                                    QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                    GetQnnOpType(node_unit.OpType()),
+                                                    std::move(input_names),
+                                                    std::vector<std::string>(transpose_input_names),
+                                                    std::move(param_tensor_names)),
+                    "Failed to add node.");
+
+  // Add Transpose nodes for each output to permute back.
+  for (size_t output_idx = 0; output_idx < 2; ++output_idx) {
+    const auto& output = outputs[output_idx];
+    const std::string& output_name = output.node_arg.Name();
+
+    TensorInfo output_info = {};
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, output_info));
+    size_t output_rank = output_info.shape.size();
+
+    std::vector<uint32_t> transpose_perm;
+    ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast<uint32_t>(axis),
+                                                 static_cast<uint32_t>(output_rank),
+                                                 transpose_perm));
+
+    std::string transpose_output_name = output_name;
+    bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name);
+
+    // TopK's second output is indices which could be INT64 dtype, and QnnTensorWrapper directly changes the dtype to
+    // INT32 during the wrapper construction. Nevertheless, if this output happens to be graph output, an additional
+    // Cast must be added to cast dtype from INT32 back to INT64.
+    bool is_cast_required = output_idx == 1 && output_info.qnn_data_type == QNN_DATATYPE_INT_64 && is_graph_output;
+    std::string cast_input_name = "";
+    if (is_cast_required) {
+      cast_input_name = transpose_output_name + "_ort_qnn_ep_cast";
+      // For the same reason described above, the original output name is now used by this Cast.
+      transpose_output_name = cast_input_name;
+      // Since additional Cast is added, below Transpose is no longer graph output.
+      is_graph_output = false;
+    }
+
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
+                                                           transpose_input_names[output_idx],
+                                                           transpose_output_name,
+                                                           transpose_input_shapes[output_idx],
+                                                           transpose_perm,
+                                                           output_info.shape,
+                                                           output_info.qnn_data_type,
+                                                           output_info.quant_param,
+                                                           do_op_validation,
+                                                           false,
+                                                           is_graph_output));
+
+    if (is_cast_required) {
+      QnnTensorWrapper cast_output_tensorwrapper(output_name,
+                                                 QNN_TENSOR_TYPE_APP_READ,
+                                                 output_info.qnn_data_type,
+                                                 output_info.quant_param.Copy(),
+                                                 std::vector<uint32_t>(output_info.shape));
+      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output_tensorwrapper)),
+                        "Failed to add tensor.");
+      ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_input_name,
+                                                        QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                        "Cast",
+                                                        {cast_input_name},
+                                                        {output_name},
+                                                        {}),
+                        "Failed to add node");
+    }
+  }
+
   return Status::OK();
 }
 
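As a worked example of the shape bookkeeping in the hunks above (a sketch, not EP code; SwapAxisWithLast is a hypothetical helper mirroring how transpose_output_shape and transpose_input_shape are derived):

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Swap shape[axis] with shape[rank - 1], as ProcessInputs and
// ProcessAttributesAndOutputs do when computing the permuted shapes.
std::vector<uint32_t> SwapAxisWithLast(std::vector<uint32_t> shape, size_t axis) {
  std::swap(shape[axis], shape.back());
  return shape;
}

// For input {1, 3, 4, 4} with axis = 1 and K = 2 (the case used in the tests below):
//   Transpose in : {1, 3, 4, 4} -> {1, 4, 4, 3}
//   QNN TopK     : {1, 4, 4, 3} -> {1, 4, 4, 2}   (top 2 on the last axis)
//   Transpose out: {1, 4, 4, 2} -> {1, 2, 4, 4}   (the ONNX TopK output shape)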

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc

Lines changed: 8 additions & 7 deletions
@@ -601,13 +601,14 @@ Status QnnModelWrapper::AddTransposeNode(NodeIndex node_index,
   ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
   const static std::string qnn_node_type = "Transpose";
 
-  CreateQnnNode(output_name,
-                QNN_OP_PACKAGE_NAME_QTI_AISW,
-                qnn_node_type,
-                {input_name},
-                {output_name},
-                {param_tensor_name},
-                do_op_validation);
+  ORT_RETURN_IF_NOT(CreateQnnNode(output_name,
+                                  QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                  qnn_node_type,
+                                  {input_name},
+                                  {output_name},
+                                  {param_tensor_name},
+                                  do_op_validation),
+                    "QNN EP: Failed to create manually inserted Qnn Transpose node.");
 
   return Status::OK();
 }

onnxruntime/core/providers/qnn/builder/qnn_utils.cc

Lines changed: 15 additions & 0 deletions
@@ -1299,6 +1299,21 @@ Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper,
   return Status::OK();
 }
 
+Status GetPermToLastAxis(uint32_t axis, uint32_t rank, std::vector<uint32_t>& perm) {
+  ORT_RETURN_IF_NOT(axis < rank, "Expected axis must be smaller than rank: ", axis, " >= ", rank);
+
+  perm.reserve(rank);
+  for (uint32_t dim = 0; dim < rank; ++dim) {
+    perm.push_back(dim);
+  }
+
+  // Swap axis with the last one.
+  perm[axis] = rank - 1;
+  perm[rank - 1] = axis;
+
+  return Status::OK();
+}
+
 }  // namespace utils
 }  // namespace qnn
 }  // namespace onnxruntime

onnxruntime/core/providers/qnn/builder/qnn_utils.h

Lines changed: 10 additions & 0 deletions
@@ -385,6 +385,16 @@ Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper,
                        bool output_symmetric,
                        bool do_op_validation);
 
+/**
+ * Get permutation to transpose given axis to the last one.
+ *
+ * @param[in] axis the current axis to be transposed
+ * @param[in] rank the expected rank for permutation
+ * @param[out] perm the permutation for transpose
+ * @return execution status of this function
+ */
+Status GetPermToLastAxis(uint32_t axis, uint32_t rank, std::vector<uint32_t>& perm);
+
 }  // namespace utils
 }  // namespace qnn
 }  // namespace onnxruntime
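A small usage sketch of the new helper (hypothetical caller; assumes the declaration above and the usual onnxruntime Status/ORT_RETURN_IF_ERROR machinery):

Status ExamplePermUsage() {
  std::vector<uint32_t> perm;
  ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(/*axis=*/1, /*rank=*/4, perm));
  // perm == {0, 3, 2, 1}; applying the same perm twice restores the original
  // layout, which is why TopK reuses it for the Transpose on both sides.
  return Status::OK();
}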

onnxruntime/test/providers/qnn/topk_op_test.cc

Lines changed: 24 additions & 6 deletions
@@ -4,12 +4,13 @@
 #if !defined(ORT_MINIMAL_BUILD)
 
 #include <string>
+#include <vector>
 
-#include "test/providers/qnn/qnn_test_utils.h"
-#include "core/graph/node_attr_utils.h"
+#include "gtest/gtest.h"
 
+#include "core/graph/node_attr_utils.h"
 #include "core/graph/onnx_protobuf.h"
-#include "gtest/gtest.h"
+#include "test/providers/qnn/qnn_test_utils.h"
 
 namespace onnxruntime {
 namespace test {
@@ -63,12 +64,12 @@ TEST_F(QnnCPUBackendTests, TopK_DynamicK_Unsupported) {
                       ExpectedEPNodeAssignment::None);  // Should not be assigned to QNN EP.
 }
 
-// Test that TopK with an axis attribute that is not the last dimension is not supported by QNN EP.
-TEST_F(QnnCPUBackendTests, TopK_NonLastAxis_Unsupported) {
+// Test TopK with an axis attribute that is not the last dimension.
+TEST_F(QnnCPUBackendTests, TopK_NonLastAxis) {
   RunTopKTestOnCPU<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                           TestInputDef<int64_t>({1}, true /* is_initializer */, {2}),
                           {utils::MakeAttribute("axis", static_cast<int64_t>(1))},
-                          ExpectedEPNodeAssignment::None);  // Should not be assigned to QNN EP.
+                          ExpectedEPNodeAssignment::All);
 }
 
 // Test that TopK that returns the top k minimum values is not supported by QNN EP.
@@ -165,6 +166,14 @@ TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U8_LastAxis) {
                                ExpectedEPNodeAssignment::All);
 }
 
+// Test 8-bit QDQ TopK on HTP backend: non-last axis
+TEST_F(QnnHTPBackendTests, TopK_U8_NonLastAxis) {
+  RunQDQTopKTestOnHTP<uint8_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
+                               TestInputDef<int64_t>({1}, true /* is_initializer */, {2}),
+                               {utils::MakeAttribute("axis", static_cast<int64_t>(1))},  // Attributes
+                               ExpectedEPNodeAssignment::All);
+}
+
 // Test 16-bit QDQ TopK on HTP backend: top 2 largest floats from last axis
 TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U16_LastAxis) {
   RunQDQTopKTestOnHTP<uint16_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)),
@@ -174,6 +183,15 @@ TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U16_LastAxis) {
                                 21);  // opset
 }
 
+// Test 16-bit QDQ TopK on HTP backend: non-last axis
+TEST_F(QnnHTPBackendTests, TopK_U16_NonLastAxis) {
+  RunQDQTopKTestOnHTP<uint16_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)),
+                                TestInputDef<int64_t>({1}, true /* is_initializer */, {2}),
+                                {utils::MakeAttribute("axis", static_cast<int64_t>(1))},  // Attributes
+                                ExpectedEPNodeAssignment::All,
+                                21);  // opset
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 }  // namespace test
 }  // namespace onnxruntime
