Skip to content

Commit 00754a7

Browse files
authored
[RSDK-10284]: MLModelService: don't check for name if there is only 1 expected tensor, return error if expected tensors not present (#394)
1 parent 92c13e2 commit 00754a7

File tree

1 file changed

+36
-12
lines changed

1 file changed

+36
-12
lines changed

src/viam/sdk/services/private/mlmodel_server.cpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,13 @@ ::grpc::Status MLModelServiceServer::Infer(
3838

3939
const auto md = mlms->metadata({});
4040
MLModelService::named_tensor_views inputs;
41-
for (const auto& input : md.inputs) {
42-
const auto where = request->input_tensors().tensors().find(input.name);
43-
if (where == request->input_tensors().tensors().end()) {
44-
// Ignore any inputs for which we don't have metadata, since
45-
// we can't validate the type info.
46-
//
47-
// TODO: Should this be an error? For now we just don't decode
48-
// those tensors.
49-
continue;
50-
}
51-
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);
41+
42+
// Check if there's only one input tensor and metadata only expects one, too
43+
if (request->input_tensors().tensors().size() == 1 && md.inputs.size() == 1) {
44+
// Special case: just one tensor, add it without name check
45+
const MLModelService::tensor_info input = md.inputs[0];
46+
const auto& tensor_pair = *request->input_tensors().tensors().begin();
47+
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(tensor_pair.second);
5248
const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type(tensor);
5349
if (tensor_type != input.data_type) {
5450
std::ostringstream message;
@@ -58,7 +54,35 @@ ::grpc::Status MLModelServiceServer::Infer(
5854
<< static_cast<ut>(tensor_type);
5955
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
6056
}
61-
inputs.emplace(std::move(input.name), std::move(tensor));
57+
inputs.emplace(tensor_pair.first, std::move(tensor));
58+
} else {
59+
// Normal case: multiple tensors, do metadata checks
60+
// If there are extra tensors in the inputs that are not found in the metadata,
61+
// they will not be passed on to the implementation.
62+
for (const auto& input : md.inputs) {
63+
const auto where = request->input_tensors().tensors().find(input.name);
64+
if (where == request->input_tensors().tensors().end()) {
65+
// if the input tensor with the expected name is not found, return an error
66+
std::ostringstream message;
67+
message << "Expected tensor input `" << input.name
68+
<< "` was not found; if you believe you have this tensor under a "
69+
"different name, rename it to the expected tensor name";
70+
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
71+
}
72+
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);
73+
const auto tensor_type =
74+
MLModelService::tensor_info::tensor_views_to_data_type(tensor);
75+
if (tensor_type != input.data_type) {
76+
std::ostringstream message;
77+
using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
78+
message << "Tensor input `" << input.name
79+
<< "` was the wrong type; expected type "
80+
<< static_cast<ut>(input.data_type) << " but got type "
81+
<< static_cast<ut>(tensor_type);
82+
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
83+
}
84+
inputs.emplace(std::move(input.name), std::move(tensor));
85+
}
6286
}
6387

6488
const auto outputs = mlms->infer(inputs, helper.getExtra());

0 commit comments

Comments (0)