Skip to content

Commit fa568e4

Browse files
skottmckay authored and gramalingam committed
Loop type shape inferencing (onnx#1591)
* Add Loop type/shape inferencing. Make Loop spec more consistent. * Remove temporary debug code. * Make work with current Loop spec. * Allow loop carried dependencies to change shape across iterations. * Make iter_num_in type correct. * Check iteration num type is tensor(int64) to match the max iterations input. Update mergeShapesAndTypes to fail if the types don't match.
1 parent 937e64c commit fa568e4

File tree

4 files changed

+208
-10
lines changed

4 files changed

+208
-10
lines changed

Diff for: onnx/defs/controlflow/defs.cc

+113-1
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,117 @@ void IfInferenceFunction(InferenceContext& ctx) {
272272
}
273273
}
274274

275+
// Type/shape inference for the Loop operator.
// Loop's inputs are: 'M' (max trip count), 'cond' (initial condition), then
// N loop-carried state values. Its 'body' subgraph takes (iter_num, cond,
// state...) and produces (cond, state..., per-iteration outputs...).
// Loop's own outputs are the final state values followed by the per-iteration
// outputs, each of which gains a leading "number of iterations" dimension.
void LoopInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  auto num_loop_state_vars = num_inputs - 2; // skip 'M' and 'cond'

  std::vector<const TypeProto*> subgraph_input_types;

  // owns the shape-stripped copies of the state-variable types; pointers into
  // this vector are handed to the subgraph inferencer, so reserve up front so
  // push_back never reallocates and invalidates them.
  std::vector<TypeProto> temporary_type_protos;
  temporary_type_protos.reserve(num_inputs - 2);

  // create TypeProto to validate iteration number type is the same as the
  // optional 'M' input for max iterations.
  TypeProto iter_num_type;
  iter_num_type.mutable_tensor_type()->set_elem_type(
      TensorProto_DataType_INT64);
  subgraph_input_types.push_back(&iter_num_type);

  // 'cond'
  subgraph_input_types.push_back(ctx.getInputType(1));

  // loop state value types get propagated to outputs, but shape may change
  // across iterations so don't propagate it to the outputs and don't pass it
  // into the subgraph inferencing
  for (size_t i = 2; i < num_inputs; ++i) {
    propagateElemTypeFromInputToOutput(ctx, i, i - 2);

    // copy so we can remove the shape before passing to the subgraph
    // inferencing
    temporary_type_protos.push_back(*ctx.getInputType(i));
    auto& input_type = temporary_type_protos.back();
    input_type.mutable_tensor_type()->clear_shape();

    subgraph_input_types.push_back(&input_type);
  }

  // Run inferencing on the subgraph
  std::vector<const TypeProto*> subgraph_output_types;

  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    // pass through any statically-known input values; the iteration number
    // has no constant value, hence the leading nullptr.
    std::vector<const TensorProto*> input_data;
    input_data.push_back(nullptr); // iteration number
    for (size_t i = 1; i < num_inputs; ++i) {
      input_data.push_back(ctx.getInputData(i));
    }

    subgraph_output_types =
        graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  // if empty(), assume inferencing was skipped
  if (!subgraph_output_types.empty()) {
    auto num_outputs = ctx.getNumOutputs();

    // subgraph outputs the condition value first but that is only used
    // internally and not returned by Loop.
    if (subgraph_output_types.size() != num_outputs + 1) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ",
          subgraph_output_types.size(),
          " outputs. Expected ",
          num_outputs + 1);
    }

    // check loop state values match. we should already have type/shape info
    for (size_t i = 0; i < num_outputs; ++i) {
      auto* subgraph_output_type = subgraph_output_types[i + 1]; // skip 'cond'
      auto* loop_output_type = ctx.getOutputType(i);

      const bool is_loop_state_var = i < num_loop_state_vars;

      if (!subgraph_output_type->has_tensor_type()) {
        fail_type_inference(
            "Loop 'body' subgraph outputs should all be tensors but output ",
            i,
            " was ",
            subgraph_output_type->value_case());
      }

      // if there's an existing type check it matches. otherwise propagate
      propagateElemTypeWithValidation(subgraph_output_type, loop_output_type);

      if (is_loop_state_var) {
        // shape may change across iterations so ignore.
      } else {
        // per iteration output. first dimension will be number of iterations
        // but we don't know that value yet
        TypeProto inferred_type(*subgraph_output_type);
        auto* mutable_inferred_tensor_type =
            inferred_type.mutable_tensor_type();
        auto* mutable_inferred_shape =
            mutable_inferred_tensor_type->mutable_shape();

        mutable_inferred_shape->clear_dim();

        // add empty dimension for number of iterations
        mutable_inferred_shape->add_dim();

        // add dimensions from subgraph output shape
        for (const auto& dim :
             subgraph_output_type->tensor_type().shape().dim()) {
          (*mutable_inferred_shape->add_dim()) = dim;
        }

        // merge the (iterations, subgraph-shape...) shape into any existing
        // output shape information
        mergeInShapeInfo(
            *mutable_inferred_tensor_type,
            *loop_output_type->mutable_tensor_type());
      }
    }
  }
}
385+
275386
ONNX_OPERATOR_SET_SCHEMA(
276387
If,
277388
1,
@@ -459,7 +570,8 @@ ONNX_OPERATOR_SET_SCHEMA(
459570
AttributeProto::GRAPH)
460571
.TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
461572
.TypeConstraint("I", {"int64"}, "Only int64")
462-
.TypeConstraint("B", {"bool"}, "Only bool"));
573+
.TypeConstraint("B", {"bool"}, "Only bool")
574+
.TypeAndShapeInferenceFunction(LoopInferenceFunction));
463575

464576
static const char* scan_ver1_doc = R"DOC(
465577
Scan can be used to iterate over one or more scan_input tensors,

Diff for: onnx/defs/shape_inference.h

+44
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,50 @@ multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
151151
return dim;
152152
}
153153

154+
// propagate the element type from an input type to an output type.
155+
// if an existing output element type exists, validate it matches.
156+
inline void propagateElemTypeWithValidation(
157+
const TypeProto* input_type,
158+
TypeProto* output_type) {
159+
if (nullptr == input_type) {
160+
fail_type_inference("Input type was null");
161+
}
162+
163+
if (input_type->value_case() != TypeProto::kTensorType) {
164+
fail_type_inference(
165+
"Input was expected to have tensor type. Got ",
166+
input_type->value_case());
167+
}
168+
169+
if (input_type->tensor_type().elem_type() == TensorProto::UNDEFINED) {
170+
fail_type_inference("Element type of input was unknown");
171+
}
172+
173+
if (output_type->value_case() == TypeProto::VALUE_NOT_SET) {
174+
output_type->mutable_tensor_type()->set_elem_type(
175+
input_type->tensor_type().elem_type());
176+
} else if (output_type->value_case() == TypeProto::kTensorType) {
177+
if (output_type->tensor_type().has_elem_type()) {
178+
if (input_type->tensor_type().elem_type() !=
179+
output_type->tensor_type().elem_type()) {
180+
fail_type_inference(
181+
"Input element type of ",
182+
input_type->tensor_type().elem_type(),
183+
" does not match existing output type of ",
184+
output_type->tensor_type().elem_type());
185+
}
186+
} else {
187+
output_type->mutable_tensor_type()->set_elem_type(
188+
input_type->tensor_type().elem_type());
189+
}
190+
} else {
191+
// This is not expected to happen
192+
fail_type_inference(
193+
"Output was expected to have tensor type. Got ",
194+
output_type->value_case());
195+
}
196+
}
197+
154198
// Note: for all methods below for propagating type or shape, callers are
155199
// responsible to handle optional inputs/outputs and ensure that the specified
156200
// index value is less than NumInputs/NumOutputs.

Diff for: onnx/shape_inference/implementation.cc

+14-3
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,16 @@ void checkShapesAndTypes(
4444
void mergeShapesAndTypes(
4545
const TypeProto_Tensor& inferredType,
4646
TypeProto_Tensor* existingType) {
47-
if (inferredType.elem_type() != TensorProto::UNDEFINED &&
48-
existingType->elem_type() == TensorProto::UNDEFINED) {
49-
existingType->set_elem_type(inferredType.elem_type());
47+
if (inferredType.elem_type() != TensorProto::UNDEFINED) {
48+
if (existingType->elem_type() == TensorProto::UNDEFINED) {
49+
existingType->set_elem_type(inferredType.elem_type());
50+
} else if (existingType->elem_type() != inferredType.elem_type()) {
51+
fail_type_inference(
52+
"type mismatch. existing=",
53+
existingType->elem_type(),
54+
" inferred=",
55+
inferredType.elem_type());
56+
}
5057
}
5158

5259
if (!inferredType.has_shape()) {
@@ -324,6 +331,10 @@ std::vector<const TypeProto*> GraphInferencerImpl::doInferencing(
324331

325332
for (int i = 0, end = numInputs; i < end; ++i) {
326333
const TypeProto* inferredInput = inputTypes[i];
334+
335+
if (!inferredInput)
336+
continue;
337+
327338
TypeProto* graphInput = g_->mutable_input(i)->mutable_type();
328339

329340
if (!graphInput->has_tensor_type()) {

Diff for: onnx/test/shape_inference_test.py

+37-6
Original file line numberDiff line numberDiff line change
@@ -956,9 +956,6 @@ def test_scan(self): # type: () -> None
956956
# can't use self._make_graph for the subgraph as it add more inputs for the Reshape operations it inserts.
957957
# this breaks the subgraph inferencing as it expects the number of inputs passed from Scan to match
958958
# the GraphProto, but Scan knows nothing about the additional inputs.
959-
value_infos = [make_tensor_value_info('loop_state_in', TensorProto.FLOAT, (loop_state_size,)),
960-
make_tensor_value_info('input', TensorProto.FLOAT, (input_size,))]
961-
962959
input_value_infos = [make_tensor_value_info('loop_state_in', TensorProto.UNDEFINED, None),
963960
make_tensor_value_info('input', TensorProto.UNDEFINED, None)]
964961
output_value_infos = [make_tensor_value_info('loop_state_out', TensorProto.UNDEFINED, None),
@@ -969,15 +966,14 @@ def test_scan(self): # type: () -> None
969966
make_node('Identity', ['input'], ['output'])],
970967
"subgraph",
971968
input_value_infos,
972-
output_value_infos,
973-
value_info=value_infos
969+
output_value_infos
974970
)
975971

976972
graph = self._make_graph(
977973
[('loop_state_orig', TensorProto.FLOAT, (batch_size, loop_state_size)),
978974
('scan_input', TensorProto.FLOAT, (batch_size, seq_len, input_size))],
979975
[make_node('Scan', ['', 'loop_state_orig', 'scan_input'], ['loop_state_final', 'scan_output'],
980-
num_scan_inputs=1, body=subgraph)],
976+
num_scan_inputs=1, body=subgraph)],
981977
[]
982978
)
983979

@@ -1053,6 +1049,41 @@ def test_onehot_with_axis(self): # type: () -> None
10531049
[])
10541050
self._assert_inferred(graph, [make_tensor_value_info('Y', TensorProto.FLOAT, (2, None, 3, 5))]) # type: ignore
10551051

1052+
def test_loop(self):  # type: () -> None
    """Check Loop type/shape inference: loop-state output gets the element
    type but no shape (shape may change across iterations), and each
    per-iteration output gains an unknown leading 'iterations' dimension."""
    # can't use self._make_graph for the subgraph as it add more inputs for the Reshape operations it inserts.
    # this breaks the subgraph inferencing as it expects the number of inputs passed from Loop to match
    # the GraphProto, but Loop knows nothing about the additional inputs.
    # subgraph inputs: (iter_num, cond, loop state); types other than the
    # iteration number are left UNDEFINED so Loop must fill them in.
    input_value_infos = [make_tensor_value_info('iter_num_in', TensorProto.INT64, (1,)),
                         make_tensor_value_info('cond_in', TensorProto.UNDEFINED, None),
                         make_tensor_value_info('loop_state_in', TensorProto.UNDEFINED, ())]
    # subgraph outputs: (cond, loop state, per-iteration output). 'output'
    # comes from an outer-scope value, so its type/shape are known here.
    output_value_infos = [make_tensor_value_info('cond_out', TensorProto.UNDEFINED, None),
                          make_tensor_value_info('loop_state_out', TensorProto.UNDEFINED, None),
                          make_tensor_value_info('output', TensorProto.FLOAT, (3,))]

    subgraph = helper.make_graph(
        [make_node('Identity', ['cond_in'], ['cond_out']),
         make_node('Identity', ['loop_state_in'], ['loop_state_out']),
         make_node('Identity', ['outer_scope_input'], ['output'])],
        "subgraph",
        input_value_infos,
        output_value_infos
    )

    # outer graph: Loop('M', cond, state) with an extra value referenced only
    # from inside the subgraph via outer scope.
    graph = self._make_graph(
        [('max_trip_count', TensorProto.INT64, (1,)),
         ('cond_orig', TensorProto.FLOAT, (1,)),
         ('loop_state_orig', TensorProto.FLOAT, (2,)),
         ('outer_scope_input', TensorProto.FLOAT, (3,))],
        [make_node('Loop', ['max_trip_count', 'cond_orig', 'loop_state_orig'], ['loop_state_final', 'loop_output'],
                   body=subgraph)],
        []
    )

    self._assert_inferred(
        graph,
        [make_tensor_value_info('loop_state_final', TensorProto.FLOAT, None),  # shape may change between iterations
         make_tensor_value_info('loop_output', TensorProto.FLOAT, (None, 3))])  # type: ignore
1086+
10561087

10571088
if __name__ == '__main__':
10581089
unittest.main()

0 commit comments

Comments
 (0)