[MacOS] Add MLProgram Gather op for CoreML EP #24387

Merged · 9 commits · Apr 15, 2025
onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc

@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/impl/builder_utils.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/shape_utils.h"
@@ -18,6 +19,7 @@
 
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
+  bool SupportsMLProgram() const override { return true; }
 };
 
 namespace {
@@ -28,13 +30,37 @@
 }  // namespace
 
 Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
-                                              const logging::Logger& /*logger*/) const {
-  auto layer = model_builder.CreateNNLayer(node);
-  layer->mutable_gather()->set_axis(GetAxisAttribute(node));
-  *layer->mutable_input()->Add() = node.InputDefs()[0]->Name();    // data
-  *layer->mutable_input()->Add() = node.InputDefs()[1]->Name();    // indices
-  *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();  // output
-  model_builder.AddLayer(std::move(layer));
+                                              const logging::Logger& logger) const {
+  if (model_builder.CreateMLProgram()) {
+    using CoreML::Specification::MILSpec::Operation;
+    std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "gather");
+
+    std::optional<int32_t> output_datatype;
+
+    int32_t input_type;
+    ORT_RETURN_IF_NOT(GetType(*node.InputDefs()[0], input_type, logger), "Failed to get input type");
+
+    if (input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
+      output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32;
+    }
+
+    const auto axis = GetAxisAttribute(node);
+    // CoreML docs claim validate_indices is optional, but in practice it is required
+    const auto validate_indices = false;
+    AddOperationInput(*op, "x", node.InputDefs()[0]->Name());        // data
+    AddOperationInput(*op, "indices", node.InputDefs()[1]->Name());  // indices
+    AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis));  // axis attr
+    AddOperationInput(*op, "validate_indices", model_builder.AddScalarConstant(op->type(), "validate_indices", validate_indices));
+    AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype);  // output
+    model_builder.AddOperation(std::move(op));
+  } else {
+    auto layer = model_builder.CreateNNLayer(node);
+    layer->mutable_gather()->set_axis(GetAxisAttribute(node));
+    *layer->mutable_input()->Add() = node.InputDefs()[0]->Name();    // data
+    *layer->mutable_input()->Add() = node.InputDefs()[1]->Name();    // indices
+    *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();  // output
+    model_builder.AddLayer(std::move(layer));
+  }
   return Status::OK();
 }
 

cpplint warnings reported by reviewdog on this file:
onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc:36: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc:62: Add #include <utility> for move [build/include_what_you_use] [4]

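For readers unfamiliar with the op being mapped: gather selects slices of data along axis at the positions given by indices (the builder above always passes validate_indices=false). A minimal rank-1, axis-0 reference in plain C++, illustration only and not ORT code; negative-index wrap-around is omitted:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Rank-1, axis-0 gather: out[i] = data[indices[i]].
// validate_indices mirrors the CoreML parameter wired up above.
std::vector<float> Gather1D(const std::vector<float>& data,
                            const std::vector<int32_t>& indices,
                            bool validate_indices) {
  std::vector<float> out;
  out.reserve(indices.size());
  for (int32_t idx : indices) {
    if (validate_indices &&
        (idx < 0 || static_cast<size_t>(idx) >= data.size())) {
      throw std::out_of_range("gather index out of range");
    }
    out.push_back(data[static_cast<size_t>(idx)]);
  }
  return out;
}

For example, Gather1D({10, 20, 30}, {2, 0}, true) yields {30, 10}. The MIL gather op provides this natively, so the builder only has to wire up the inputs and the int32 index handling.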
57 changes: 43 additions & 14 deletions onnxruntime/core/providers/coreml/builders/model_builder.cc
@@ -71,9 +71,36 @@ void CopyRawDataToRepeatedField(const ONNX_NAMESPACE::TensorProto& tensor_proto,
   }
 }
 
+template <>
+void CopyRawDataToRepeatedField<int64_t, int32_t>(const ONNX_NAMESPACE::TensorProto& tensor_proto,
+                                                  google::protobuf::RepeatedField<int32_t>& repeated_field) {
+  const auto& raw_data = tensor_proto.raw_data();
+  const int64_t* data = reinterpret_cast<const int64_t*>(raw_data.data());
+  const size_t element_count = raw_data.size() / sizeof(int64_t);
+
+  // Reserve space to avoid multiple reallocations
+  repeated_field.Reserve(narrow<int>(element_count));
+
+  // Use std::transform with proper iterators
+  std::transform(data, data + element_count,
+                 google::protobuf::RepeatedFieldBackInserter(&repeated_field),
+                 [](int64_t v) {
+                   return narrow<int32_t>(v);
+                 });
+}
+
+void CopyInt64DataToInt32(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue& tensor_value) {
+  const int num_entries = tensor_proto.int64_data_size();
+  auto& int32_out = *tensor_value.mutable_ints()->mutable_values();
+  int32_out.Reserve(num_entries);
+  for (int i = 0; i < num_entries; ++i) {
+    int32_out.AddAlreadyReserved(narrow<int32_t>(tensor_proto.int64_data(i)));
+  }
+}
+
 // copy T data from the TensorProto.int32_t field to TensorValue.bytes
 template <typename T>
-void CopyInt32DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) {
+void CopyInt32DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue& tensor_value) {
   const int num_entries = tensor_proto.int32_data_size();
   std::string& bytes = *tensor_value.mutable_bytes()->mutable_values();
   bytes.resize(num_entries * sizeof(T));
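The CopyRawDataToRepeatedField<int64_t, int32_t> specialization above is the narrowing path for raw initializer data. A self-contained sketch of the same pattern, with std::vector standing in for google::protobuf::RepeatedField and a gsl::narrow-style round-trip check standing in for ORT's narrow<> (both stand-ins are assumptions, not the real types):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// gsl::narrow-style check: cast, then verify the value round-trips, so an
// out-of-range value fails loudly instead of silently wrapping.
int32_t NarrowChecked(int64_t v) {
  const int32_t narrowed = static_cast<int32_t>(v);
  if (static_cast<int64_t>(narrowed) != v) {
    throw std::runtime_error("int64 value does not fit in int32");
  }
  return narrowed;
}

// Same shape as the specialization above: reinterpret the raw bytes as int64
// elements and narrow each one into the 32-bit output container.
// (Assumes raw_data holds a whole number of aligned little-endian int64
// values, as TensorProto.raw_data does.)
std::vector<int32_t> RawInt64ToInt32(const std::string& raw_data) {
  const int64_t* data = reinterpret_cast<const int64_t*>(raw_data.data());
  const size_t element_count = raw_data.size() / sizeof(int64_t);

  std::vector<int32_t> out;
  out.reserve(element_count);  // mirrors repeated_field.Reserve()
  std::transform(data, data + element_count, std::back_inserter(out),
                 NarrowChecked);
  return out;
}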
@@ -87,7 +114,7 @@ void CopyInt32DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue& tensor_value) {
 
 // copy T data from the TensorProto.uint64_data field to TensorValue.bytes
 template <typename T>
-void CopyUInt64DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) {
+void CopyUInt64DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue& tensor_value) {
   const int num_entries = tensor_proto.uint64_data_size();
   std::string& bytes = *tensor_value.mutable_bytes()->mutable_values();
   bytes.resize(num_entries * sizeof(T));
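The MILSpec::TensorValue to MILSpec::TensorValue& signature changes in the two hunks above read as a correctness fix rather than a style fix: a protobuf message passed by value is a deep copy, so the helper would fill a temporary and the caller's tensor would stay empty. A minimal sketch of that failure mode, with a plain struct as a hypothetical stand-in for MILSpec::TensorValue:

#include <cassert>
#include <string>

struct Value {
  std::string bytes;  // stand-in for TensorValue's bytes payload
};

void FillByValue(Value v) { v.bytes = "data"; }       // mutates a copy; lost
void FillByReference(Value& v) { v.bytes = "data"; }  // mutates the caller's object

int main() {
  Value v;
  FillByValue(v);
  assert(v.bytes.empty());    // the write went to the temporary
  FillByReference(v);
  assert(v.bytes == "data");  // the write is visible to the caller
  return 0;
}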
@@ -143,18 +170,16 @@ void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto,
       break;
     }
     case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
-      // enable when this is proven to not be the case
-      ORT_THROW(
-          "INT64 is unexpected as CoreML uses 32-bit int for indices. "
-          "Most likely an initializer that should have been skipped was not.");
-      //// from: int64_data/raw, to: longints
-      // if (has_raw_data) {
-      //  CopyRawDataToRepeatedField<int64_t>(tensor_proto, *tensor_value.mutable_longints()->mutable_values());
-      //} else {
-      //  tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data());
-      //}
-      // break;
+      // from: int64_data/raw, to: ints (use narrow to convert to int32)
+      // CoreML tensors have a longints field, but the CoreML op definitions only use int32,
+      // so we convert any int64 to int32
+      if (has_raw_data) {
+        CopyRawDataToRepeatedField<int64_t, int32_t>(tensor_proto, *tensor_value.mutable_ints()->mutable_values());
+
+      } else {
+        CopyInt64DataToInt32(tensor_proto, tensor_value);
+      }
       break;
     }
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
       // from: int32_data/raw, to: bytes
@@ -356,7 +381,11 @@ MILSpec::Value OnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto,
   // populate ValueType with tensor data type, dims and rank
   MILSpec::ValueType& value_type = *value.mutable_type();
   MILSpec::TensorType& tensor_type = *value_type.mutable_tensortype();
-  tensor_type.set_datatype(OnnxDataTypeToMILSpec(tensor_proto.data_type()));
+  MILSpec::DataType data_type = OnnxDataTypeToMILSpec(tensor_proto.data_type());
+  MILSpec::DataType converted_data_type = data_type == MILSpec::DataType::INT64
+                                              ? MILSpec::DataType::INT32
+                                              : data_type;
+  tensor_type.set_datatype(converted_data_type);
 
   tensor_type.set_rank(tensor_proto.dims().size());
   for (const auto& dim : tensor_proto.dims()) {
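To restate the remapping introduced here in isolation: the declared MIL tensor datatype must agree with the int64-to-int32 payload conversion performed in CopyOnnxTensorToCoreMLTensor above, so both sides map INT64 to INT32. A trivial sketch with a hypothetical stand-in enum for MILSpec::DataType:

#include <cassert>
#include <cstdint>

// Abbreviated stand-in for MILSpec::DataType (assumption: only these members).
enum class DataType : int32_t { INT32, INT64, FLOAT32 };

// Mirrors the ternary above: int64 is declared as int32 because the CoreML
// MIL op definitions only work with 32-bit integer tensors.
DataType RemapForMILProgram(DataType t) {
  return t == DataType::INT64 ? DataType::INT32 : t;
}

int main() {
  assert(RemapForMILProgram(DataType::INT64) == DataType::INT32);
  assert(RemapForMILProgram(DataType::FLOAT32) == DataType::FLOAT32);
  return 0;
}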