Commit 216fa97

Ac2zoomhouseroad authored and committed
Broadcast Version Conversion Adapters (Add, Mul, Gemm) (onnx#1284)
* Squashing all commits into one
* Fixed supported adapters list
* Long Int troubles
* Will int64s save the day?
* Remove comments in test_backend_test
* Updated broadcastable methods to accurately reflect purpose
* Attempt to resolve unsigned int issue
* Fixed trans bug in gemm
* New Gemm_6_7
* Fixed axis index
* Changed name of CompatibleAdapter and added description in BroadcastBackwardCompatible
* Attempt to resolve int64_t conversion issues in helper
* Unsigned int issue
* New approach to int64_t issue and novel numpy_multibroadcastable approach
* Unsqueeze instead of Reshape for broadcast_forward_compatibility, more assertions of sizes not being params
* Made unallowed_types a constructor parameter for TypeRestriction
* Changed long int back to int64_t in broadcast_forward_compatibility
* renamed opset7_requires_broadcasting
* Abstracted input validation
* Switched num_inputs in assertInputsAvailable from int to int64_t
* switch int64_t to uint64_t
* Removed test_mvn
* One more cast to int from int64_t in assertInputsAvailable
* Addressed code review except numpy_unibroadcastable
* insertBefore instead of moveBefore, changing BackwardsCompatibleTest, setting sizes of unsqueeze output
* fixed naming of broadcasting assertions
* Changed assert_numpy_unibroadcastable_and_require_broadcast to check_numpy_unibroadcastable_and_require_broadcast
1 parent 6146a85 commit 216fa97

23 files changed: +652 −39 lines
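The squashed history above turns on a `numpy_unibroadcastable` check: a pre-opset-7 Add/Mul can only be emitted when the second input broadcasts to the first without altering the first's shape. A minimal Python sketch of that unidirectional rule follows; the helper is illustrative only, and the real check lives in the C++ version-converter helpers.

```python
# Illustrative sketch of the unidirectional (numpy-style) broadcastability
# rule named in the commit message; not the adapters' actual C++ code.
def numpy_unibroadcastable(shape1, shape2):
    # type: (tuple, tuple) -> bool
    """True if shape2 broadcasts to shape1 without altering shape1."""
    if len(shape2) > len(shape1):
        return False
    # Align trailing dimensions, as numpy broadcasting does.
    for d1, d2 in zip(shape1[::-1], shape2[::-1]):
        if d2 != 1 and d2 != d1:
            return False
    return True

assert numpy_unibroadcastable((5,), (1,))        # the Add/Mul test shapes below
assert numpy_unibroadcastable((2, 3, 4), (3, 4))
assert not numpy_unibroadcastable((5,), (2,))
```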

.gitignore (+1)

@@ -86,6 +86,7 @@ compile_commands.json
 .coverage
 onnx/examples/.coverage.nbval
 .pytest_cache
+test_report
 
 # autocomplete
 .ycm_extra_conf.py

docs/VersionConverter.md (+10 −8)

@@ -32,19 +32,21 @@ ModelProto ConvertVersion(
     const OpSetID& target_version);
 ```
 
-which accepts an input `ModelProto`, the initial opset version of the model,
-and the target opset verison, and which returns a new `ModelProto` which
+which accepts an input `ModelProto`, the initial opset version of the model,
+and the target opset verison, and which returns a new `ModelProto` which
 is the result of apply all relevant adapters between initial_version and
-target_version. For a list of available passes, see
+target_version. For a list of available passes, see
 [convert.h](onnx/version_converter/convert.h).
 
 Implementing Adapters
 
-You can implement a new adapter by subclassing `Adapter`, and registering
-your new adapter with `VersionConverter::registerAdapter()`. Adapters operate
-on an in-memory graph representation defined in [ir.h](onnx/common/ir.h).
-There are a number of examples in the [adapters](onnx/version_converter/adapters)
-directory.
+You can implement a new adapter by subclassing `Adapter`, and registering
+your new adapter with `VersionConverter::registerAdapter()`. Adapters operate
+on an in-memory graph representation defined in [ir.h](onnx/common/ir.h).
+There are a number of examples in the [adapters](onnx/version_converter/adapters)
+directory. Please ensure that all adapters convert from opset version i to i + 1
+or i - 1, i.e. from Version 6 to Version 5 or vice versa, even if the 2 versions
+being converted between are Version 1 and Version 6.
 
 If your adapter applies in the default domain, please consider adding it
 to the core ONNX repository
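The `ConvertVersion` entry point documented above is also reachable from Python through `onnx.version_converter`, which the tests below exercise. A minimal usage sketch, building the same opset-8 Add model as the new test cases:

```python
# Minimal sketch of driving the version converter from Python; the model
# mirrors the Add test cases added in this commit.
from onnx import helper, version_converter, TensorProto

graph = helper.make_graph(
    [helper.make_node('Add', ["X1", "X2"], ["Y"])],
    "test",
    [helper.make_tensor_value_info("X1", TensorProto.FLOAT, (5,)),
     helper.make_tensor_value_info("X2", TensorProto.FLOAT, (1,))],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5,))])
model = helper.make_model(
    graph, opset_imports=[helper.make_operatorsetid("", 8)])

# Per the doc change above, the converter chains i -> i - 1 adapters,
# so 8 -> 5 walks through opsets 7 and 6 along the way.
converted = version_converter.convert_version(model, 5)
assert converted.opset_import[0].version == 5
```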

onnx/common/interned_strings.h (+4 −1)

@@ -136,7 +136,10 @@ _(body) \
 _(then_branch) \
 _(else_branch) \
 _(Captured) \
-_(__control_inputs)
+_(__control_inputs) \
+_(count_include_pad) \
+_(storage_order) \
+_(Unsqueeze)
 
 enum BuiltinSymbol {
 #define DEFINE_SYMBOL(s) \

onnx/common/ir.h (−1)

@@ -816,7 +816,6 @@ class OpSetID final {
     }
   }
 
-
   const std::string& domain() const {
     return domain_;
   }

onnx/common/ir_pb_converter.cc (+2 −3)

@@ -507,18 +507,17 @@ void encodeGraph(GraphProto * p_g, const std::shared_ptr<Graph> & g) {
   }
   for(auto output : node->outputs()) {
     p_n->add_output(value_name(output));
-
     // only save it if
     // - it has actual information worth saving
     // - it's not already saved in the graph outputs value info
     if (graph_outputs.find(output) != graph_outputs.end()) {
       continue;
     }
-    if (output->elemType() == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED &&
+    if (output->elemType() == TensorProto_DataType_UNDEFINED &&
         output->sizes().empty()) {
      continue;
    }
-    ONNX_NAMESPACE::ValueInfoProto* v = p_g->add_value_info();
+    ValueInfoProto* v = p_g->add_value_info();
     encodeValueInfo(v, output);
   }
   p_n->set_op_type(node->kind().toString());

onnx/cpp2py_export.cc (+1)

@@ -268,6 +268,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
       [](const py::bytes& bytes, const py::int_ target) {
         ModelProto proto{};
         ParseProtoFromPyBytes(&proto, bytes);
+        shape_inference::InferShapes(proto);
         auto const result = version_conversion::ConvertVersion(std::move(proto),
           target);
         std::string out;
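Shape inference now runs before conversion because the broadcast adapters need input sizes to decide whether, and how, to adapt a node. A rough Python-level equivalent of the ordering the binding enforces internally; the file name is a placeholder:

```python
# Rough Python-level equivalent of the binding's new behavior: infer shapes
# first so adapters can inspect input sizes, then convert.
import onnx
from onnx import shape_inference, version_converter

model = onnx.load("model.onnx")  # hypothetical model path
model = shape_inference.infer_shapes(model)
converted = version_converter.convert_version(model, 5)
```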

onnx/optimizer/optimize.cc (+1)

@@ -5,6 +5,7 @@
 
 namespace ONNX_NAMESPACE { namespace optimization {
 
+// TODO: Remove this static reference
 static Optimizer _optimizer;
 
 ModelProto Optimize(

onnx/test/test_backend_test.py (+5 −4)

@@ -12,6 +12,7 @@
 from onnx.backend.base import Device, DeviceType
 from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
 import onnx.shape_inference
+import onnx.version_converter
 from typing import Optional, Text, Any, Tuple, Sequence
 from onnx import NodeProto, ModelProto, TensorProto
 import numpy  # type: ignore
@@ -41,7 +42,7 @@ def prepare(cls,
         model = onnx.shape_inference.infer_shapes(model)
         value_infos = {vi.name: vi for vi in itertools.chain(model.graph.value_info, model.graph.output)}
 
-        if do_enforce_shape_inference_coverage(model):
+        if do_enforce_test_coverage_whitelist(model):
             for node in model.graph.node:
                 for i, output in enumerate(node.output):
                     if node.op_type == 'Dropout' and i != 0:
@@ -75,13 +76,13 @@ def supports_device(cls, device):  # type: (Text) -> bool
         return False
 
 
-shape_coverage_whitelist = set(
+test_coverage_whitelist = set(
     ['bvlc_alexnet', 'densenet121', 'inception_v1', 'inception_v2',
      'resnet50', 'shufflenet', 'SingleRelu', 'squeezenet_old', 'vgg19', 'zfnet'])
 
 
-def do_enforce_shape_inference_coverage(model):  # type: (ModelProto) -> bool
-    if model.graph.name not in shape_coverage_whitelist:
+def do_enforce_test_coverage_whitelist(model):  # type: (ModelProto) -> bool
+    if model.graph.name not in test_coverage_whitelist:
         return False
     for node in model.graph.node:
         if node.op_type in set(['RNN', 'LSTM', 'GRU']):

onnx/test/version_converter_test.py (+100 −5)

@@ -31,17 +31,20 @@ def _converted(
     # Test 1: Backwards Incompatible Conversion: Reshape: 8 -> 2
     def test_backwards_incompatible(self):  # type: () -> None
         def test():  # type: () -> None
-            nodes = [helper.make_node('Reshape', ["X", "shape"], ["Y"])]
+            nodes = [helper.make_node('Add', ["W", "Z"], ["shape"]),
+                helper.make_node('Reshape', ["X", "shape"], ["A"]),
+                helper.make_node('Add', ["A", "W"], ["Y"])]
             graph = helper.make_graph(
                 nodes,
                 "test",
                 [helper.make_tensor_value_info("X", TensorProto.FLOAT, (5,)),
-                    helper.make_tensor_value_info("shape", TensorProto.FLOAT, (1,))],
+                    helper.make_tensor_value_info("W", TensorProto.FLOAT, (1,)),
+                    helper.make_tensor_value_info("Z", TensorProto.FLOAT, (1,))],
                 [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5,))])
             self._converted(graph, helper.make_operatorsetid("", 8), 2)
         self.assertRaises(RuntimeError, test)
 
-    # Test 2: Backwards Compatible Conversion: Add: 8 -> 7
+    # Test 2: Backwards Compatible Conversion (No Adaptations): Add: 3 -> 2
     def test_backwards_compatible(self):  # type: () -> None
         nodes = [helper.make_node('Add', ["X1", "X2"], ["Y"])]
         graph = helper.make_graph(
@@ -51,10 +54,10 @@ def test_backwards_compatible(self):  # type: () -> None
                 helper.make_tensor_value_info("X2", TensorProto.FLOAT, (5,))],
             [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5,))])
         converted_model = self._converted(graph, helper.make_operatorsetid(
-            "", 8), 7)
+            "", 3), 2)
         # Assert equality of graph and converted_model
         assert converted_model.graph.node[0].op_type == "Add"
-        assert converted_model.opset_import[0].version == 7
+        assert converted_model.opset_import[0].version == 2
 
     # Test 3: Non-Existent Op Conversion: Cos: 8 -> 6
     def test_non_existent_op(self):  # type: () -> None
@@ -68,6 +71,98 @@ def test():  # type: () -> None
             self._converted(graph, helper.make_operatorsetid("", 8), 6)
         self.assertRaises(RuntimeError, test)
 
+    # Test Add Adapter: 8 -> 5
+    def test_add_8_5(self):  # type: () -> None
+        nodes = [helper.make_node('Add', ["X1", "X2"], ["Y"])]
+        graph = helper.make_graph(
+            nodes,
+            "test",
+            [helper.make_tensor_value_info("X1", TensorProto.FLOAT, (5,)),
+                helper.make_tensor_value_info("X2", TensorProto.FLOAT, (1,))],
+            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5,))])
+        converted_model = self._converted(graph, helper.make_operatorsetid(
+            "", 8), 5)
+        # Assert equality of graph and converted_model
+        assert converted_model.graph.node[0].op_type == "Add"
+        assert converted_model.opset_import[0].version == 5
+
+    # Test Add Adapter: 5 -> 8
+    def test_add_5_8(self):  # type: () -> None
+        nodes = [helper.make_node('Add', ["X1", "X2"], ["Y"])]
+        graph = helper.make_graph(
+            nodes,
+            "test",
+            [helper.make_tensor_value_info("X1", TensorProto.FLOAT, (5,)),
+                helper.make_tensor_value_info("X2", TensorProto.FLOAT, (1,))],
+            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5,))])
+        converted_model = self._converted(graph, helper.make_operatorsetid(
+            "", 5), 8)
+        # Assert equality of graph and converted_model
+        assert converted_model.graph.node[0].op_type == "Add"
+        assert converted_model.opset_import[0].version == 8
+
+    # Test Mul Adapter: 8 -> 5
+    def test_mul_8_5(self):  # type: () -> None
+        nodes = [helper.make_node('Mul', ["X1", "X2"], ["Y"])]
+        graph = helper.make_graph(
+            nodes,
+            "test",
+            [helper.make_tensor_value_info("X1", TensorProto.FLOAT, (5,)),
+                helper.make_tensor_value_info("X2", TensorProto.FLOAT, (1,))],
+            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5,))])
+        converted_model = self._converted(graph, helper.make_operatorsetid(
+            "", 8), 5)
+        # Assert equality of graph and converted_model
+        assert converted_model.graph.node[0].op_type == "Mul"
+        assert converted_model.opset_import[0].version == 5
+
+    # Test Mul Adapter: 5 -> 8
+    def test_mul_5_8(self):  # type: () -> None
+        nodes = [helper.make_node('Mul', ["X1", "X2"], ["Y"])]
+        graph = helper.make_graph(
+            nodes,
+            "test",
+            [helper.make_tensor_value_info("X1", TensorProto.FLOAT, (5,)),
+                helper.make_tensor_value_info("X2", TensorProto.FLOAT, (1,))],
+            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5,))])
+        converted_model = self._converted(graph, helper.make_operatorsetid(
+            "", 5), 8)
+        # Assert equality of graph and converted_model
+        assert converted_model.graph.node[0].op_type == "Mul"
+        assert converted_model.opset_import[0].version == 8
+
+    # Test Gemm Adapter: 1 -> 8
+    def test_gemm_up(self):  # type: () -> None
+        nodes = [helper.make_node('Gemm', ["A", "B", "C"], ["Y"])]
+        graph = helper.make_graph(
+            nodes,
+            "test",
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (5, 5,)),
+                helper.make_tensor_value_info("B", TensorProto.FLOAT, (5, 5,)),
+                helper.make_tensor_value_info("C", TensorProto.FLOAT, (5, 5,))],
+            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5, 5,))])
+        converted_model = self._converted(graph, helper.make_operatorsetid(
+            "", 1), 8)
+        # Assert equality of graph and converted_model
+        assert converted_model.graph.node[0].op_type == "Gemm"
+        assert converted_model.opset_import[0].version == 8
+
+    # Test Gemm Adapter: 8 -> 1
+    def test_gemm_down(self):  # type: () -> None
+        nodes = [helper.make_node('Gemm', ["A", "B", "C"], ["Y"])]
+        graph = helper.make_graph(
+            nodes,
+            "test",
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (5, 5,)),
+                helper.make_tensor_value_info("B", TensorProto.FLOAT, (5, 5,)),
+                helper.make_tensor_value_info("C", TensorProto.FLOAT, (5, 5,))],
+            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5, 5,))])
+        converted_model = self._converted(graph, helper.make_operatorsetid(
+            "", 8), 1)
+        # Assert equality of graph and converted_model
+        assert converted_model.graph.node[0].op_type == "Gemm"
+        assert converted_model.opset_import[0].version == 1
+
 
 if __name__ == '__main__':
     unittest.main()
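These tests span the opset-7 broadcasting change: Add and Mul broadcast numpy-style from opset 7 onward, while earlier opsets need explicit `broadcast=1` and `axis` attributes (and, per the commit message, a downgrade adapter may insert an Unsqueeze so the second input takes a legal pre-7 shape). A numpy sketch of the two regimes the adapters translate between, illustrative rather than the adapter code itself:

```python
# Illustrative numpy sketch of the broadcasting regimes involved; not the
# adapter implementation.
import numpy as np

x1 = np.ones((5,), dtype=np.float32)
x2 = np.ones((1,), dtype=np.float32)

# Opset >= 7: implicit numpy-style (multidirectional) broadcasting.
y_new = x1 + x2

# Opset <= 6 with broadcast=1: when the smaller input does not align
# trailing-first, an axis attribute says where it lines up. For A of shape
# (2, 3, 4) and B of shape (3,), Add(A, B, broadcast=1, axis=1) behaves like:
a = np.ones((2, 3, 4), dtype=np.float32)
b = np.ones((3,), dtype=np.float32)
y_old = a + b.reshape((1, 3, 1))

assert y_new.shape == (5,) and y_old.shape == (2, 3, 4)
```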
