@@ -38,17 +38,13 @@ ::grpc::Status MLModelServiceServer::Infer(
38
38
39
39
const auto md = mlms->metadata ({});
40
40
MLModelService::named_tensor_views inputs;
41
- for (const auto & input : md.inputs ) {
42
- const auto where = request->input_tensors ().tensors ().find (input.name );
43
- if (where == request->input_tensors ().tensors ().end ()) {
44
- // Ignore any inputs for which we don't have metadata, since
45
- // we can't validate the type info.
46
- //
47
- // TODO: Should this be an error? For now we just don't decode
48
- // those tensors.
49
- continue ;
50
- }
51
- auto tensor = mlmodel::make_sdk_tensor_from_api_tensor (where->second );
41
+
42
+ // Check if there's only one input tensor and metadata only expects one, too
43
+ if (request->input_tensors ().tensors ().size () == 1 && md.inputs .size () == 1 ) {
44
+ // Special case: just one tensor, add it without name check
45
+ const MLModelService::tensor_info input = md.inputs [0 ];
46
+ const auto & tensor_pair = *request->input_tensors ().tensors ().begin ();
47
+ auto tensor = mlmodel::make_sdk_tensor_from_api_tensor (tensor_pair.second );
52
48
const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type (tensor);
53
49
if (tensor_type != input.data_type ) {
54
50
std::ostringstream message;
@@ -58,7 +54,35 @@ ::grpc::Status MLModelServiceServer::Infer(
58
54
<< static_cast <ut>(tensor_type);
59
55
return helper.fail (::grpc::INVALID_ARGUMENT, message.str ().c_str ());
60
56
}
61
- inputs.emplace (std::move (input.name ), std::move (tensor));
57
+ inputs.emplace (tensor_pair.first , std::move (tensor));
58
+ } else {
59
+ // Normal case: multiple tensors, do metadata checks
60
+ // If there are extra tensors in the inputs that not found in the metadata,
61
+ // they will not be passed on to the implementation.
62
+ for (const auto & input : md.inputs ) {
63
+ const auto where = request->input_tensors ().tensors ().find (input.name );
64
+ if (where == request->input_tensors ().tensors ().end ()) {
65
+ // if the input vector of the expected name is not found, return an error
66
+ std::ostringstream message;
67
+ message << " Expected tensor input `" << input.name
68
+ << " ` was not found; if you believe you have this tensor under a "
69
+ " different name, rename it to the expected tensor name" ;
70
+ return helper.fail (::grpc::INVALID_ARGUMENT, message.str ().c_str ());
71
+ }
72
+ auto tensor = mlmodel::make_sdk_tensor_from_api_tensor (where->second );
73
+ const auto tensor_type =
74
+ MLModelService::tensor_info::tensor_views_to_data_type (tensor);
75
+ if (tensor_type != input.data_type ) {
76
+ std::ostringstream message;
77
+ using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
78
+ message << " Tensor input `" << input.name
79
+ << " ` was the wrong type; expected type "
80
+ << static_cast <ut>(input.data_type ) << " but got type "
81
+ << static_cast <ut>(tensor_type);
82
+ return helper.fail (::grpc::INVALID_ARGUMENT, message.str ().c_str ());
83
+ }
84
+ inputs.emplace (std::move (input.name ), std::move (tensor));
85
+ }
62
86
}
63
87
64
88
const auto outputs = mlms->infer (inputs, helper.getExtra ());
0 commit comments