[RSDK-10284]: MLModelService: don't check for name if there is only 1 expected tensor, return error if expected tensors not present #394

Merged: 10 commits, Mar 27, 2025
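
In practice this changes the `Infer` RPC in two ways: a request carrying exactly one input tensor for a model whose metadata expects exactly one is now accepted regardless of the tensor's name, and a request missing an expected input now fails with `INVALID_ARGUMENT` instead of the missing tensor being silently skipped. Below is a minimal client-side sketch of the single-input case. It assumes a connected `MLModelService` client `mlms` and the SDK's `MLModelService::make_tensor_view` helper; the header path, model input name ("image"), tensor name, and shape are illustrative assumptions, not taken from this PR.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

#include <viam/sdk/services/mlmodel.hpp>

using namespace viam::sdk;

// Infer against a model whose metadata declares exactly one input tensor
// (hypothetically named "image"). With this PR, the request tensor's name no
// longer has to match that metadata name.
void infer_single_input(const std::shared_ptr<MLModelService>& mlms) {
    std::vector<float> data{0.1F, 0.2F, 0.3F, 0.4F};
    std::vector<std::size_t> shape{1, 4};

    MLModelService::named_tensor_views inputs;
    // "my_tensor" deliberately does not match "image"; the server now accepts
    // it because both the request and the metadata contain a single tensor.
    inputs.emplace("my_tensor",
                   MLModelService::make_tensor_view(data.data(), data.size(), std::move(shape)));

    // The data-type check against the metadata still runs server-side.
    const auto outputs = mlms->infer(inputs, {});
    (void)outputs;
}
```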
48 changes: 36 additions & 12 deletions src/viam/sdk/services/private/mlmodel_server.cpp
@@ -38,17 +38,13 @@ ::grpc::Status MLModelServiceServer::Infer(
 
     const auto md = mlms->metadata({});
     MLModelService::named_tensor_views inputs;
-    for (const auto& input : md.inputs) {
-        const auto where = request->input_tensors().tensors().find(input.name);
-        if (where == request->input_tensors().tensors().end()) {
-            // Ignore any inputs for which we don't have metadata, since
-            // we can't validate the type info.
-            //
-            // TODO: Should this be an error? For now we just don't decode
-            // those tensors.
-            continue;
-        }
-        auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);
+
+    // Check if there's only one input tensor and metadata only expects one, too.
+    if (request->input_tensors().tensors().size() == 1 && md.inputs.size() == 1) {
+        // Special case: just one tensor, so add it without a name check.
+        const MLModelService::tensor_info input = md.inputs[0];
+        const auto& tensor_pair = *request->input_tensors().tensors().begin();
+        auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(tensor_pair.second);
         const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type(tensor);
         if (tensor_type != input.data_type) {
             std::ostringstream message;
@@ -58,7 +54,35 @@ ::grpc::Status MLModelServiceServer::Infer(
                 << static_cast<ut>(tensor_type);
             return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
         }
-        inputs.emplace(std::move(input.name), std::move(tensor));
-    }
+        inputs.emplace(tensor_pair.first, std::move(tensor));
+    } else {
+        // Normal case: multiple tensors, so do the metadata checks. Any extra
+        // input tensors that are not named in the metadata are not passed on
+        // to the implementation.
+        for (const auto& input : md.inputs) {
+            const auto where = request->input_tensors().tensors().find(input.name);
+            if (where == request->input_tensors().tensors().end()) {
+                // If no input tensor with the expected name was provided, return an error.
+                std::ostringstream message;
+                message << "Expected tensor input `" << input.name
+                        << "` was not found; if you believe you have this tensor under a "
+                           "different name, rename it to the expected tensor name";
+                return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
+            }
+            auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);
+            const auto tensor_type =
+                MLModelService::tensor_info::tensor_views_to_data_type(tensor);
+            if (tensor_type != input.data_type) {
+                std::ostringstream message;
+                using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
+                message << "Tensor input `" << input.name
+                        << "` was the wrong type; expected type "
+                        << static_cast<ut>(input.data_type) << " but got type "
+                        << static_cast<ut>(tensor_type);
+                return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
+            }
+            inputs.emplace(std::move(input.name), std::move(tensor));
+        }
+    }
 
     const auto outputs = mlms->infer(inputs, helper.getExtra());
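
For the multi-input case, here is a hedged sketch of the new failure mode under the same assumptions as above, with a hypothetical model that declares inputs "image" and "mask". Omitting "mask" previously caused that input to be skipped silently; the server now rejects the request, and the SDK surfaces the gRPC failure as a C++ exception (the exact exception type and `make_tensor_view` signature may differ slightly by SDK version).

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

#include <viam/sdk/services/mlmodel.hpp>

using namespace viam::sdk;

// Call Infer while omitting an input ("mask") that the model's metadata
// declares. The server now fails the RPC with INVALID_ARGUMENT.
void infer_missing_input(const std::shared_ptr<MLModelService>& mlms) {
    std::vector<float> pixels(640 * 480, 0.0F);
    std::vector<std::size_t> shape{480, 640};

    MLModelService::named_tensor_views inputs;
    inputs.emplace("image",
                   MLModelService::make_tensor_view(pixels.data(), pixels.size(), std::move(shape)));
    // "mask" is required by the metadata but deliberately not provided.

    try {
        mlms->infer(inputs, {});
    } catch (const std::exception& ex) {
        // Expected message: "Expected tensor input `mask` was not found; if you
        // believe you have this tensor under a different name, rename it to the
        // expected tensor name"
        std::cerr << ex.what() << "\n";
    }
}
```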