Skip to content

[QNN-EP] Support non-last axis TopK. #24881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,6 @@ std::vector<uint32_t> FlattenShapeFromAxis(const std::vector<uint32_t>& input_sh
return output_shape;
}

// Builds the permutation that moves `axis` to the final position by swapping it
// with the last dimension; every other dimension keeps its original place.
std::vector<uint32_t> GetTransposePermToUseLastAxis(uint32_t input_rank, uint32_t axis) {
  assert(axis < input_rank);

  const uint32_t last_dim = input_rank - 1;

  // Start from the identity permutation [0, 1, ..., rank - 1].
  std::vector<uint32_t> transpose_perm(input_rank);
  for (uint32_t dim = 0; dim < input_rank; ++dim) {
    transpose_perm[dim] = dim;
  }

  // Exchange the requested axis with the final dimension.
  transpose_perm[axis] = last_dim;
  transpose_perm[last_dim] = axis;

  return transpose_perm;
}

Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
Expand Down Expand Up @@ -131,8 +115,10 @@ Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
QNN EP is able to support arbitrary axis attribute by wrapping transposes around the operator.
*/
std::string transpose_output_name = input_name + "_ort_qnn_ep_transpose";
std::vector<uint32_t> transpose_perm = GetTransposePermToUseLastAxis(static_cast<uint32_t>(input_rank),
static_cast<uint32_t>(axis));
std::vector<uint32_t> transpose_perm;
ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast<uint32_t>(axis),
static_cast<uint32_t>(input_rank),
transpose_perm));

std::vector<uint32_t> transpose_output_shape = input_info.shape;
transpose_output_shape[input_rank - 1] = input_info.shape[axis];
Expand Down Expand Up @@ -243,8 +229,10 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_
"Failed to add node.");

const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(orig_output_name);
std::vector<uint32_t> transpose_perm = GetTransposePermToUseLastAxis(static_cast<uint32_t>(output_rank),
static_cast<uint32_t>(axis));
std::vector<uint32_t> transpose_perm;
ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast<uint32_t>(axis),
static_cast<uint32_t>(output_rank),
transpose_perm));

ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
transpose_input_name,
Expand Down
176 changes: 165 additions & 11 deletions onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"

namespace onnxruntime {
namespace qnn {

const int TOPK_MIN_INPUT = 2;
const int TOPK_MAX_INPUT = 2;

class TopKOpBuilder : public BaseOpBuilder {
public:
TopKOpBuilder() : BaseOpBuilder("TopKOpBuilder") {}
Expand Down Expand Up @@ -41,8 +50,11 @@ class TopKOpBuilder : public BaseOpBuilder {

Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
size_t input_count = node_unit.Inputs().size();
size_t output_count = node_unit.Outputs().size();
ORT_RETURN_IF_NOT(input_count >= TOPK_MIN_INPUT && input_count <= TOPK_MAX_INPUT,
"For ONNX TopK operation the expected number of inputs is 2.");
ORT_RETURN_IF_NOT(output_count == 2, "QNN TopK expects exactly 2 outputs.");

// Skip the first input. The second input needs to be an initializer.
const auto& input_1 = node_unit.Inputs()[1].node_arg.Name();
if (!qnn_model_wrapper.IsConstantInput(input_1)) {
Expand All @@ -57,14 +69,6 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N
if (0 == largest) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN TopK output is always largest values");
}
auto& input_0 = node_unit.Inputs()[0];
std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), "Cannot get shape");
auto rank = input_shape.size();
auto axis = node_helper.Get("axis", -1);

ORT_RETURN_IF_NOT(axis == -1 || axis == static_cast<int32_t>(rank - 1),
"QNN TopK's axis is always the last dimension");

return Status::OK();
}
Expand All @@ -81,6 +85,40 @@ Status TopKOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));

// HTP only supports TopK at the last axis, and thus check whether extra Transpose is required.
TensorInfo input_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));

size_t input_rank = input_info.shape.size();
int32_t axis = NodeAttrHelper(node_unit).Get("axis", -1);
if (axis == -1 || axis == static_cast<int32_t>(input_rank - 1)) {
return Status::OK();
}

// Add Transpose to permute axis to the last.
std::string transpose_output_name = input_names[0] + "_ort_qnn_ep_transpose";
std::vector<uint32_t> transpose_perm;
ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast<uint32_t>(axis),
static_cast<uint32_t>(input_rank),
transpose_perm));

std::vector<uint32_t> transpose_output_shape = input_info.shape;
transpose_output_shape[input_rank - 1] = input_info.shape[axis];
transpose_output_shape[axis] = input_info.shape[input_rank - 1];

ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
input_names[0],
transpose_output_name,
input_info.shape,
transpose_perm,
transpose_output_shape,
input_info.qnn_data_type,
input_info.quant_param,
do_op_validation,
false,
false));
input_names[0] = transpose_output_name;

return Status::OK();
}

Expand Down Expand Up @@ -108,9 +146,125 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
std::string k_param_name = k_param.GetParamTensorName();
qnn_model_wrapper.AddParamWrapper(std::move(k_param));
std::vector<std::string> param_tensor_names{k_param_name};
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names),
std::move(param_tensor_names), logger, do_op_validation,
GetQnnOpType(node_unit.OpType())));

// HTP only supports TopK at the last axis, and thus check whether extra Transpose is required.
TensorInfo input_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));

size_t input_rank = input_info.shape.size();
int32_t axis = NodeAttrHelper(node_unit).Get("axis", -1);
if (axis == -1 || axis == static_cast<int32_t>(input_rank - 1)) {
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper,
node_unit,
std::move(input_names),
std::move(param_tensor_names),
logger,
do_op_validation,
GetQnnOpType(node_unit.OpType())));
return Status::OK();
}

const auto& outputs = node_unit.Outputs();
std::vector<std::string> transpose_input_names;
std::vector<std::vector<std::uint32_t>> transpose_input_shapes;

// Add TopK outputs.
for (size_t output_idx = 0; output_idx < 2; ++output_idx) {
const auto& output = outputs[output_idx];

// Since user may not be aware of the additional Transpose, the original output name of TopK node must be used by
// the additional Transpose node which has the same output as original TopK node.
const std::string& output_name = output.node_arg.Name();
std::string transpose_input_name = output_name + "_ort_qnn_ep_transpose";
transpose_input_names.push_back(std::move(transpose_input_name));

// Since the input of TopK node is permuted, its output shape must be manually calculated.
TensorInfo output_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, output_info));
size_t output_rank = output_info.shape.size();

std::vector<uint32_t> transpose_input_shape = output_info.shape;
transpose_input_shape[output_rank - 1] = output_info.shape[axis];
transpose_input_shape[axis] = output_info.shape[output_rank - 1];
transpose_input_shapes.push_back(std::move(transpose_input_shape));

QnnTensorWrapper output_tensorwrapper(transpose_input_names[output_idx],
QNN_TENSOR_TYPE_NATIVE,
output_info.qnn_data_type,
output_info.quant_param.Copy(),
std::vector<uint32_t>(transpose_input_shapes[output_idx]));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
}

// Add TopK node.
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
QNN_OP_PACKAGE_NAME_QTI_AISW,
GetQnnOpType(node_unit.OpType()),
std::move(input_names),
std::vector<std::string>(transpose_input_names),
std::move(param_tensor_names)),
"Failed to add node.");

// Add Transpose nodes for each output to permute back.
for (size_t output_idx = 0; output_idx < 2; ++output_idx) {
const auto& output = outputs[output_idx];
const std::string& output_name = output.node_arg.Name();

TensorInfo output_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, output_info));
size_t output_rank = output_info.shape.size();

std::vector<uint32_t> transpose_perm;
ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast<uint32_t>(axis),
static_cast<uint32_t>(output_rank),
transpose_perm));

std::string transpose_output_name = output_name;
bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name);

// TopK's second output is indices which could be INT64 dtype, and QnnTensorWrapper directly changes the dtype to
// INT32 during the wrapper construction. Nevertheless, if this output happens to be graph output, an additional
// Cast must be added to cast dtype from INT32 back to INT64.
bool is_cast_required = output_idx == 1 && output_info.qnn_data_type == QNN_DATATYPE_INT_64 && is_graph_output;
std::string cast_input_name = "";
if (is_cast_required) {
cast_input_name = transpose_output_name + "_ort_qnn_ep_cast";
// For the same reason described above, the original output name is now used by this Cast.
transpose_output_name = cast_input_name;
// Since additional Cast is added, below Transpose is no longer graph output.
is_graph_output = false;
}

ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
transpose_input_names[output_idx],
transpose_output_name,
transpose_input_shapes[output_idx],
transpose_perm,
output_info.shape,
output_info.qnn_data_type,
output_info.quant_param,
do_op_validation,
false,
is_graph_output));

if (is_cast_required) {
QnnTensorWrapper cast_output_tensorwrapper(output_name,
QNN_TENSOR_TYPE_APP_READ,
output_info.qnn_data_type,
output_info.quant_param.Copy(),
std::vector<uint32_t>(output_info.shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output_tensorwrapper)),
"Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_input_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
{cast_input_name},
{output_name},
{}),
"Failed to add node");
}
}

return Status::OK();
}

Expand Down
15 changes: 8 additions & 7 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -601,13 +601,14 @@ Status QnnModelWrapper::AddTransposeNode(NodeIndex node_index,
ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
const static std::string qnn_node_type = "Transpose";

CreateQnnNode(output_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
qnn_node_type,
{input_name},
{output_name},
{param_tensor_name},
do_op_validation);
ORT_RETURN_IF_NOT(CreateQnnNode(output_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
qnn_node_type,
{input_name},
{output_name},
{param_tensor_name},
do_op_validation),
"QNN EP: Failed to create manually inserted Qnn Transpose node.");

return Status::OK();
}
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,21 @@ Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}

// Fills `perm` with the permutation that transposes `axis` to the last
// position while keeping all other dimensions in their original order,
// e.g. axis=1, rank=4 -> [0, 3, 2, 1].
//
// @param[in]  axis  axis to move to the end; must satisfy axis < rank
// @param[in]  rank  rank of the tensor the permutation applies to
// @param[out] perm  receives the permutation (any previous contents discarded)
// @return error Status when axis is out of range, Status::OK() otherwise
Status GetPermToLastAxis(uint32_t axis, uint32_t rank, std::vector<uint32_t>& perm) {
  ORT_RETURN_IF_NOT(axis < rank, "Expected axis must be smaller than rank: ", axis, " >= ", rank);

  // Clear first so a caller-supplied non-empty vector cannot corrupt the
  // result (previously the identity entries were appended after old contents).
  perm.clear();
  perm.reserve(rank);
  for (uint32_t dim = 0; dim < rank; ++dim) {
    perm.push_back(dim);
  }

  // Swap axis with the last one.
  perm[axis] = rank - 1;
  perm[rank - 1] = axis;

  return Status::OK();
}

} // namespace utils
} // namespace qnn
} // namespace onnxruntime
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,16 @@ Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper,
bool output_symmetric,
bool do_op_validation);

/**
* Get permutation to transpose given axis to the last one.
*
* @param[in] axis the current axis to be transposed
* @param[in] rank the expected rank for permutation
* @param[out] perm the permutation for transpose
* @return execution status of this function
*/
Status GetPermToLastAxis(uint32_t axis, uint32_t rank, std::vector<uint32_t>& perm);

} // namespace utils
} // namespace qnn
} // namespace onnxruntime
30 changes: 24 additions & 6 deletions onnxruntime/test/providers/qnn/topk_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
#if !defined(ORT_MINIMAL_BUILD)

#include <string>
#include <vector>

#include "test/providers/qnn/qnn_test_utils.h"
#include "core/graph/node_attr_utils.h"
#include "gtest/gtest.h"

#include "core/graph/node_attr_utils.h"
#include "core/graph/onnx_protobuf.h"
#include "gtest/gtest.h"
#include "test/providers/qnn/qnn_test_utils.h"

namespace onnxruntime {
namespace test {
Expand Down Expand Up @@ -63,12 +64,12 @@ TEST_F(QnnCPUBackendTests, TopK_DynamicK_Unsupported) {
ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP.
}

// Test that TopK with an axis attribute that is not the last dimension is not supported by QNN EP.
TEST_F(QnnCPUBackendTests, TopK_NonLastAxis_Unsupported) {
// Test that TopK with an axis attribute that is not the last dimension is supported by QNN EP.
TEST_F(QnnCPUBackendTests, TopK_NonLastAxis) {
RunTopKTestOnCPU<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
TestInputDef<int64_t>({1}, true /* is_initializer */, {2}),
{utils::MakeAttribute("axis", static_cast<int64_t>(1))},
ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP.
ExpectedEPNodeAssignment::All);
}

// Test that TopK that returns the top k minimum values is not supported by QNN EP.
Expand Down Expand Up @@ -165,6 +166,14 @@ TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U8_LastAxis) {
ExpectedEPNodeAssignment::All);
}

// Test 8-bit QDQ TopK on HTP backend: non-last axis
// Test 8-bit QDQ TopK on HTP backend: non-last axis
// The topk.cc builder inserts a Transpose before TopK (moving axis 1 to the last
// position) and a Transpose after each output, so all nodes should be assigned to QNN EP.
TEST_F(QnnHTPBackendTests, TopK_U8_NonLastAxis) {
RunQDQTopKTestOnHTP<uint8_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
TestInputDef<int64_t>({1}, true /* is_initializer */, {2}),
{utils::MakeAttribute("axis", static_cast<int64_t>(1))}, // Attributes
ExpectedEPNodeAssignment::All);
}

// Test 16-bit QDQ TopK on HTP backend: top 2 largest floats from last axis
TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U16_LastAxis) {
RunQDQTopKTestOnHTP<uint16_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)),
Expand All @@ -174,6 +183,15 @@ TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U16_LastAxis) {
21); // opset
}

// Test 16-bit QDQ TopK on HTP backend: non-last axis
// Test 16-bit QDQ TopK on HTP backend: non-last axis
// Mirrors TopK_U8_NonLastAxis but with 16-bit activations at opset 21; relies on the
// same Transpose-wrapping path in topk.cc to satisfy HTP's last-axis-only TopK.
TEST_F(QnnHTPBackendTests, TopK_U16_NonLastAxis) {
RunQDQTopKTestOnHTP<uint16_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)),
TestInputDef<int64_t>({1}, true /* is_initializer */, {2}),
{utils::MakeAttribute("axis", static_cast<int64_t>(1))}, // Attributes
ExpectedEPNodeAssignment::All,
21); // opset
}

#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
} // namespace test
} // namespace onnxruntime
Expand Down
Loading