Skip to content

Commit 843cb4f

Browse files
authored
Support optional operators (onnx#980)
* add optional data type support Signed-off-by: degaochu <[email protected]> * fix seq test model issue Signed-off-by: degaochu <[email protected]> * update support status Signed-off-by: degaochu <[email protected]> * support identity opset 16 Signed-off-by: degaochu <[email protected]> * fix issue Signed-off-by: degaochu <[email protected]> * fix issues for onnx 1.9.0 Signed-off-by: degaochu <[email protected]> * fix attr proto cross version issue Signed-off-by: degaochu <[email protected]> * update opset list and status doc Signed-off-by: degaochu <[email protected]>
1 parent 4500476 commit 843cb4f

12 files changed

+295
-199
lines changed

doc/support_status.md

Lines changed: 175 additions & 175 deletions
Large diffs are not rendered by default.

onnx_tf/backend.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,10 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
240240
node = OnnxNode(node)
241241
input_tensors = []
242242
for i in inputs:
243-
input_tensors.append(tf.constant(i))
243+
if i is None:
244+
input_tensors.append(i)
245+
else:
246+
input_tensors.append(tf.constant(i))
244247

245248
if isinstance(inputs, dict):
246249
feed_dict_raw = inputs
@@ -252,7 +255,15 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
252255
input_dict = {}
253256
for k, v in feed_dict_raw.items():
254257
if isinstance(v, list):
255-
input_dict[k] = [tf.constant(x) for x in v]
258+
list_input = []
259+
for x in v:
260+
if x is None:
261+
list_input.append(x)
262+
else:
263+
list_input.append(tf.constant(x))
264+
input_dict[k] = list_input
265+
elif v is None: # keep None for empty optional data
266+
input_dict[k] = v
256267
else:
257268
input_dict[k] = tf.constant(v)
258269

onnx_tf/backend_rep.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def run(self, inputs, **kwargs):
7878
if isinstance(inputs, dict):
7979
feed_dict = inputs
8080
elif isinstance(inputs, list) or isinstance(inputs, tuple):
81+
#to do: handle input is seq(optional) and optional element is empty(if/loop opset 16)
8182
if len(self.inputs) != len(inputs):
8283
raise RuntimeError('Expected {} values for uninitialized '
8384
'graph inputs ({}), but got {}.'.format(
@@ -91,7 +92,15 @@ def run(self, inputs, **kwargs):
9192
input_dict = {}
9293
for k, v in feed_dict.items():
9394
if isinstance(v, list):
94-
input_dict[k] = [tf.constant(x) for x in v]
95+
list_input = []
96+
for x in v:
97+
if x is None:
98+
list_input.append(x)
99+
else:
100+
list_input.append(tf.constant(x))
101+
input_dict[k] = list_input
102+
elif v is None: # keep None for empty optional data
103+
input_dict[k] = v
95104
else:
96105
input_dict[k] = tf.constant(v)
97106

@@ -109,6 +118,8 @@ def run(self, inputs, **kwargs):
109118
o_values.append(v_list)
110119
elif isinstance(output_values[o_name], tf.Tensor):
111120
o_values.append(output_values[o_name].numpy())
121+
elif isinstance(output_values[o_name], tf.RaggedTensor):
122+
o_values.append([i for i in output_values[o_name].numpy()])
112123
else:
113124
o_values.append(output_values[o_name])
114125

onnx_tf/common/attr_converter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from onnx_tf.common import IS_PYTHON3
2+
from onnx_tf.common.legacy import legacy_opset_pre_ver
23

34

45
def convert_tf(attr):
@@ -82,5 +83,7 @@ def __convert_onnx_attribute_proto(attr_proto):
8283
return str_list
8384
elif attr_proto.HasField('sparse_tensor'):
8485
return attr_proto.sparse_tensor
86+
elif not legacy_opset_pre_ver(15) and attr_proto.HasField('tp'):
87+
return attr_proto.tp
8588
else:
8689
raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))

onnx_tf/handlers/backend/identity.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,17 @@ def version_13(cls, node, **kwargs):
1818
return [cls.make_tensor_from_onnx_node(node, **kwargs)]
1919

2020
@classmethod
21-
def version_14(cls, node, **kwargs):
21+
def _common(cls, node, **kwargs):
2222
x = kwargs["tensor_dict"][node.inputs[0]]
2323
if isinstance(x, (list, tuple)):
2424
return [tf.identity_n(x)]
2525
else:
2626
return [tf.identity(x)]
27+
28+
@classmethod
29+
def version_14(cls, node, **kwargs):
30+
return cls._common(node, **kwargs)
31+
32+
@classmethod
33+
def version_16(cls, node, **kwargs):
34+
return cls._common(node, **kwargs)

onnx_tf/handlers/backend/optional.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from onnx_tf.handlers.backend_handler import BackendHandler
2+
from onnx_tf.handlers.handler import onnx_op
3+
4+
5+
@onnx_op("Optional")
6+
class Optional(BackendHandler):
7+
8+
@classmethod
9+
def version_15(cls, node, **kwargs):
10+
if len(node.inputs) > 0:
11+
return [kwargs["tensor_dict"][node.inputs[0]]]
12+
else:
13+
return [None]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from onnx_tf.handlers.backend_handler import BackendHandler
2+
from onnx_tf.handlers.handler import onnx_op
3+
4+
5+
@onnx_op("OptionalGetElement")
6+
class OptionalGetElement(BackendHandler):
7+
8+
@classmethod
9+
def version_15(cls, node, **kwargs):
10+
if len(node.inputs) > 0:
11+
return [kwargs["tensor_dict"][node.inputs[0]]]
12+
else:
13+
raise RuntimeError("No element value!.")
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from onnx_tf.handlers.backend_handler import BackendHandler
2+
from onnx_tf.handlers.handler import onnx_op
3+
4+
5+
@onnx_op("OptionalHasElement")
6+
class OptionalHasElement(BackendHandler):
7+
8+
@classmethod
9+
def version_15(cls, node, **kwargs):
10+
if len(node.inputs) > 0 and kwargs["tensor_dict"][node.inputs[0]] is not None:
11+
return [True]
12+
else:
13+
return [False]

onnx_tf/opset_version.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
'HardSigmoid': [1, 6],
6969
'HardSwish': [14],
7070
'Hardmax': [1, 11, 13],
71-
'Identity': [1, 13, 14],
71+
'Identity': [1, 13, 14, 16],
7272
'If': [1, 11, 13],
7373
'ImageScaler': [1],
7474
'Imputer': [],
@@ -109,9 +109,9 @@
109109
'Not': [1],
110110
'OneHot': [9, 11],
111111
'OneHotEncoder': [],
112-
'Optional': [],
113-
'OptionalGetElement': [],
114-
'OptionalHasElement': [],
112+
'Optional': [15],
113+
'OptionalGetElement': [15],
114+
'OptionalHasElement': [15],
115115
'Or': [1, 7],
116116
'PRelu': [1, 6, 7, 9],
117117
'Pad': [1, 2, 11, 13],

test/backend/test_model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -364,16 +364,16 @@ def test_if_with_sequence(self):
364364
'c': c,
365365
'cond': np.array(True, dtype=np.bool)
366366
})
367-
np.testing.assert_almost_equal(output['S_final'].values[:2], a)
368-
np.testing.assert_almost_equal(output['S_final'].values[2:], b)
367+
np.testing.assert_almost_equal(output['S_final'][0], a)
368+
np.testing.assert_almost_equal(output['S_final'][1], b)
369369
output = tf_rep.run({
370370
'a': a,
371371
'b': b,
372372
'c': c,
373373
'cond': np.array(False, dtype=np.bool)
374374
})
375-
np.testing.assert_almost_equal(output['S_final'].values[:2], a)
376-
np.testing.assert_almost_equal(output['S_final'].values[2:], c)
375+
np.testing.assert_almost_equal(output['S_final'][0], a)
376+
np.testing.assert_almost_equal(output['S_final'][1], c)
377377

378378
def test_initializer(self):
379379
if legacy_onnx_pre_ver(1, 2):
@@ -460,10 +460,10 @@ def test_loop_with_sequence(self):
460460
outputs=[s_final_out])
461461
tf_rep = prepare(helper.make_model(graph_def))
462462
output = tf_rep.run({'a': a, 'b': b, 'M': M, 'cond_init': cond})
463-
np.testing.assert_almost_equal(output['S_final'].values[:2], a)
464-
np.testing.assert_almost_equal(output['S_final'].values[2:3], b)
465-
np.testing.assert_almost_equal(output['S_final'].values[3:4], b)
466-
np.testing.assert_almost_equal(output['S_final'].values[4:5], b)
463+
np.testing.assert_almost_equal(output['S_final'][0], a)
464+
np.testing.assert_almost_equal(output['S_final'][1], b)
465+
np.testing.assert_almost_equal(output['S_final'][2], b)
466+
np.testing.assert_almost_equal(output['S_final'][3], b)
467467

468468
def test_pow_bfloat16(self):
469469
X1 = np.array([1, 2, 3]).astype(np.float32)

test/backend/test_node.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4198,7 +4198,28 @@ def test_where(self):
41984198
output = run_node(node_def, [c, x, y])
41994199
np.testing.assert_almost_equal(output["Z"], np.where(c, x, y))
42004200

4201+
def test_optional(self):
4202+
if legacy_opset_pre_ver(16):
4203+
raise unittest.SkipTest("ONNX version {} doesn't support Optional.".format(
4204+
defs.onnx_opset_version()))
4205+
ten_in_tp = helper.make_tensor_type_proto(TensorProto.FLOAT, shape=[5])
4206+
opt_in_tp = helper.make_optional_type_proto(ten_in_tp)
4207+
node_def = helper.make_node("Optional", ["X"], ["Y"], type=opt_in_tp)
4208+
x = self._get_rnd_float32(-3.0, 3.0, [5])
4209+
output = run_node(node_def, [x])
4210+
np.testing.assert_almost_equal(output["Y"], x)
42014211

4212+
def test_optional_empty(self):
4213+
if legacy_opset_pre_ver(16):
4214+
raise unittest.SkipTest("ONNX version {} doesn't support Optional.".format(
4215+
defs.onnx_opset_version()))
4216+
ten_in_tp = helper.make_tensor_type_proto(TensorProto.FLOAT, shape=[5])
4217+
opt_in_tp = helper.make_optional_type_proto(ten_in_tp)
4218+
node_def = helper.make_node("Optional", ["X"], ["Y"], type=opt_in_tp)
4219+
x = None
4220+
output = run_node(node_def, [x])
4221+
np.testing.assert_equal(output["Y"], x)
4222+
42024223
if __name__ == '__main__':
42034224
unittest.main()
42044225

test/backend/test_onnx_backend.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,6 @@ def get_onnx_supported_ops():
7676
if legacy_opset_pre_ver(13):
7777
backend_test.exclude(r'test_cumsum_[a-z,_]*')
7878

79-
# Currently ONNX's backend test runner does not support sequence as input/output
80-
backend_test.exclude(r'test_if_seq[a-z,_]*')
81-
8279
# TF session run does not support sequence/RaggedTensor as model inputs
8380
backend_test.exclude(r'test_loop13_seq[a-z,_]*')
8481

@@ -149,9 +146,10 @@ def get_onnx_supported_ops():
149146
backend_test.exclude(r'test_bernoulli_double_expanded[a-z,_]*')
150147
backend_test.exclude(r'test_bernoulli_seed_expanded[a-z,_]*')
151148

152-
# Exclude optional_get_element, test_optional_has_element tests
153-
backend_test.exclude(r'test_optional_get_element[a-z,_]*')
154-
backend_test.exclude(r'test_optional_has_element[a-z,_]*')
149+
# # onnx backend test support seq from 1.11 #3731
150+
if legacy_opset_pre_ver(16):
151+
backend_test.exclude(r'test_optional_get_element[a-z,_]*')
152+
backend_test.exclude(r'test_optional_has_element[a-z,_]*')
155153

156154
# Exclude BatchNormalization with training_mode=1 tests
157155
backend_test.exclude(r'test_batchnorm_epsilon_training_mode[a-z,_]*')
@@ -161,8 +159,13 @@ def get_onnx_supported_ops():
161159
if legacy_opset_pre_ver(15):
162160
backend_test.exclude(r'[a-z,_]*identity_sequence_[a-z,_]*')
163161

164-
# Exclude tests with optional inputs
165-
backend_test.exclude(r'[a-z,_]*identity_opt_[a-z,_]*')
162+
# onnx backend test support seq from 1.11 #3731
163+
if legacy_opset_pre_ver(16):
164+
backend_test.exclude(r'[a-z,_]*identity_opt_[a-z,_]*')
165+
166+
if legacy_opset_pre_ver(16):
167+
backend_test.exclude(r'test_if_seq[a-z,_]*')
168+
166169
backend_test.exclude(r'[a-z,_]*if_opt_[a-z,_]*')
167170
backend_test.exclude(r'[a-z,_]*loop16_seq_none_[a-z,_]*')
168171

0 commit comments

Comments
 (0)