
Commit 9d80e43

Updated support status for Cast. It is partial, as it can only cast from string to float32/float64/int32/int64. (onnx#642)

1 parent 095b51b commit 9d80e43

File tree

8 files changed: +117 −34 lines changed
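Taken together, the user-visible change is that an opset-13 Cast (including Cast to bfloat16) now resolves to a handler instead of failing as unsupported. A rough sketch of the effect, assuming onnx, this onnx-tf master, and Tensorflow v2.2.0 are installed (run_node is onnx-tf's single-node helper):

import numpy as np
from onnx import helper, TensorProto
from onnx_tf.backend import run_node

# Before this commit Cast had no version_13 handler and BFLOAT16 had
# no tf dtype mapping; both paths now work.
node = helper.make_node("Cast", ["input"], ["output"], to=TensorProto.BFLOAT16)
output = run_node(node, [np.array([1.5, 2.5], dtype=np.float32)])
print(output["output"].dtype)  # bfloat16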

doc/support_status.md

Lines changed: 8 additions & 9 deletions
@@ -1,8 +1,7 @@
 # ONNX-Tensorflow Support Status
 |||
 |-:|:-|
-
-|ONNX-Tensorflow Version|Master ( commit id: 4748f3ea8135057cccfc5e30c0f4339a913d4ebd )|
+|ONNX-Tensorflow Version|Master ( commit id: ee0e90567b9772fab4974607bcc64fe7d410865b )|
 |ONNX Version|Master ( commit id: cc2230603422bae893d5bc900d2d773ab34400a4 )|
 |Tensorflow Version|v2.2.0|

@@ -22,16 +21,16 @@ Notes:
 |Acosh|-|-|-|-|-|-|-|-|**9**|9|9|9|9|Acosh|
 |Add|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|**13**:small_red_triangle:|Add|
 |And|**1**|1|1|1|1|1|**7**|7|7|7|7|7|7|And|
-|ArgMax|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**:small_red_triangle:|ArgMax|
-|ArgMin|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**:small_red_triangle:|ArgMin|
+|ArgMax|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**|ArgMax|
+|ArgMin|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**|ArgMin|
 |Asin|-|-|-|-|-|-|**7**|7|7|7|7|7|7|Asin|
 |Asinh|-|-|-|-|-|-|-|-|**9**|9|9|9|9|Asinh|
 |Atan|-|-|-|-|-|-|**7**|7|7|7|7|7|7|Atan|
 |Atanh|-|-|-|-|-|-|-|-|**9**|9|9|9|9|Atanh|
 |AveragePool|**1**|1|1|1|1|1|**7**|7|7|**10**|**11**|11|11|AveragePool|
 |BatchNormalization|**1**|1|1|1|1|**6**|**7**|7|**9**|9|9|9|9|BatchNormalization|
 |BitShift|-|-|-|-|-|-|-|-|-|-|**11**|11|11|BitShift|
-|Cast|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**6**:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|**9**:small_orange_diamond:|9:small_orange_diamond:|9:small_orange_diamond:|9:small_orange_diamond:|**13**:small_red_triangle:|Cast|
+|Cast|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**6**:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|**9**:small_orange_diamond:|9:small_orange_diamond:|9:small_orange_diamond:|9:small_orange_diamond:|**13**:small_orange_diamond:|Cast|
 |Ceil|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**:small_red_triangle:|Ceil|
 |Celu|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|12:small_red_triangle:|Celu|
 |Clip|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**6**:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|**11**:small_orange_diamond:|**12**:small_orange_diamond:|**13**:small_orange_diamond:|Clip|

@@ -124,9 +123,9 @@ Notes:
 |ReduceL2|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceL2|
 |ReduceLogSum|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceLogSum|
 |ReduceLogSumExp|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceLogSumExp|
-|ReduceMax|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**:small_red_triangle:|**13**:small_red_triangle:|ReduceMax|
+|ReduceMax|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**:small_red_triangle:|ReduceMax|
 |ReduceMean|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceMean|
-|ReduceMin|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**:small_red_triangle:|**13**:small_red_triangle:|ReduceMin|
+|ReduceMin|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**:small_red_triangle:|ReduceMin|
 |ReduceProd|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceProd|
 |ReduceSum|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceSum|
 |ReduceSumSquare|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceSumSquare|

@@ -180,10 +179,10 @@ Notes:
 |Where|-|-|-|-|-|-|-|-|**9**|9|9|9|9|Where|
 |Xor|**1**|1|1|1|1|1|**7**|7|7|7|7|7|7|Xor|

-ONNX-TF Supported Operators / ONNX Operators: 80 / 162
+ONNX-TF Supported Operators / ONNX Operators: 83 / 162

 Notes:
-1. Cast: Cast string to float32/float64/int32/int64 are not supported in Tensorflow.
+1. Cast: Cast string to data types other than float32/float64/int32/int64 is not supported in Tensorflow
 2. Clip: Clip input in uint64 is not supported in Tensorflow.
 3. ConcatFromSequence: new_axis=1 not supported in Tensorflow.
 4. ConvTranspose: ConvTranspose with dilations != 1, or transposed convolution for 4D or higher are not supported in Tensorflow.
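Note 1 describes a Tensorflow limitation rather than an onnx-tf one: tf.strings.to_number, which the Cast handler uses for string inputs, only accepts float32/float64/int32/int64 as out_type. A quick illustration in plain Tensorflow (assumes TF 2.x):

import tensorflow as tf

s = tf.constant(["1.5", "2.5"])
print(tf.strings.to_number(s, out_type=tf.float64))  # works: one of the four allowed types
# tf.strings.to_number(s, out_type=tf.bfloat16)      # raises: out_type must be
#                                                    # float32/float64/int32/int64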

onnx_tf/common/data_type.py

Lines changed: 5 additions & 0 deletions
@@ -39,6 +39,11 @@ def tf2onnx(dtype):


 def onnx2tf(dtype):
+  # The onnx2tf is done by going to a np type first. However,
+  # given that there is no bfloat16 in np at this time, we need
+  # to go directly to tf bfloat16 for now.
+  if dtype == int(TensorProto.BFLOAT16):
+    return tf.as_dtype("bfloat16")
   return tf.as_dtype(mapping.TENSOR_TYPE_TO_NP_TYPE[_onnx_dtype(dtype)])

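A minimal sketch of what the new branch changes, assuming onnx and Tensorflow are installed (onnx2tf is the function patched above):

from onnx import TensorProto
import tensorflow as tf
from onnx_tf.common import data_type

# BFLOAT16 now short-circuits to the tf dtype; the usual route through
# mapping.TENSOR_TYPE_TO_NP_TYPE would fail, since np has no bfloat16.
assert data_type.onnx2tf(TensorProto.BFLOAT16) == tf.bfloat16
# Every other dtype still goes through the numpy mapping.
assert data_type.onnx2tf(TensorProto.FLOAT) == tf.float32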

onnx_tf/handlers/backend/arg_max.py

Lines changed: 4 additions & 0 deletions
@@ -43,3 +43,7 @@ def version_11(cls, node, **kwargs):
   @classmethod
   def version_12(cls, node, **kwargs):
     return cls._common(node, **kwargs)
+
+  @classmethod
+  def version_13(cls, node, **kwargs):
+    return cls._common(node, **kwargs)

onnx_tf/handlers/backend/arg_min.py

Lines changed: 4 additions & 0 deletions
@@ -43,3 +43,7 @@ def version_11(cls, node, **kwargs):
   @classmethod
   def version_12(cls, node, **kwargs):
     return cls._common(node, **kwargs)
+
+  @classmethod
+  def version_13(cls, node, **kwargs):
+    return cls._common(node, **kwargs)
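Both handlers follow onnx-tf's version-dispatch convention: a node is routed to the highest version_N classmethod whose N does not exceed the model's opset, so without version_13 an opset-13 ArgMax/ArgMin has no method to land on. A hypothetical standalone sketch of that dispatch rule (illustration only, not the actual handler.py logic):

class HandlerSketch:

  @classmethod
  def version_1(cls, node, **kwargs):
    return "v1 path"

  @classmethod
  def version_13(cls, node, **kwargs):
    return "v13 path"

  @classmethod
  def handle(cls, node, opset, **kwargs):
    # Collect the declared version_N methods, here [1, 13].
    versions = sorted(
        int(name.rsplit("_", 1)[1])
        for name in dir(cls) if name.startswith("version_"))
    # Pick the highest declared version not exceeding the model opset.
    chosen = max(v for v in versions if v <= opset)
    return getattr(cls, "version_%d" % chosen)(node, **kwargs)

print(HandlerSketch.handle(None, opset=13))  # v13 path
print(HandlerSketch.handle(None, opset=12))  # v1 path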

onnx_tf/handlers/backend/cast.py

Lines changed: 21 additions & 12 deletions
@@ -6,28 +6,20 @@
 from onnx_tf.handlers.handler import partial_support
 from onnx_tf.handlers.handler import ps_description

-
 @onnx_op("Cast")
 @tf_func(tf.cast)
 @partial_support(True)
-@ps_description("Cast string to float32/float64/int32/int64 " +
-                "are not supported in Tensorflow.")
+@ps_description("Cast string to data types other than " +
+                "float32/float64/int32/int64 is not supported in Tensorflow")
+
 class Cast(BackendHandler):

   @classmethod
   def get_attrs_processor_param(cls):
     return {"rename": {"to": "dtype"}}

   @classmethod
-  def version_1(cls, node, **kwargs):
-    return [cls.make_tensor_from_onnx_node(node, **kwargs)]
-
-  @classmethod
-  def version_6(cls, node, **kwargs):
-    return [cls.make_tensor_from_onnx_node(node, **kwargs)]
-
-  @classmethod
-  def version_9(cls, node, **kwargs):
+  def _common(cls, node, **kwargs):
     inp = kwargs["tensor_dict"][node.inputs[0]]
     to_type = node.attrs.get("to")

@@ -42,3 +34,20 @@ def version_9(cls, node, **kwargs):
       return [tf.strings.to_number(inp, to_type)]

     return [cls.make_tensor_from_onnx_node(node, **kwargs)]
+
+  @classmethod
+  def version_1(cls, node, **kwargs):
+    return [cls.make_tensor_from_onnx_node(node, **kwargs)]
+
+  @classmethod
+  def version_6(cls, node, **kwargs):
+    return [cls.make_tensor_from_onnx_node(node, **kwargs)]
+
+  @classmethod
+  def version_9(cls, node, **kwargs):
+    return cls._common(node, **kwargs)
+
+  @classmethod
+  def version_13(cls, node, **kwargs):
+    return cls._common(node, **kwargs)
+
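The refactor folds the string-handling logic that used to live only in version_9 into _common, so opset 9 and 13 share it: string inputs go through tf.strings.to_number, everything else falls through to tf.cast via make_tensor_from_onnx_node. A quick end-to-end check with onnx-tf's run_node helper (assumes this branch is installed):

import numpy as np
from onnx import helper, TensorProto
from onnx_tf.backend import run_node

# A string input exercises the tf.strings.to_number branch of _common.
node = helper.make_node("Cast", ["input"], ["output"], to=TensorProto.FLOAT)
out = run_node(node, [np.array([b"3.14159", b"1e-5"])])
print(out["output"].dtype)  # float32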

onnx_tf/opset_version.py

Lines changed: 5 additions & 5 deletions
@@ -6,8 +6,8 @@
     'Adam': [],
     'Add': [1, 6, 7],
     'And': [1, 7],
-    'ArgMax': [1, 11, 12],
-    'ArgMin': [1, 11, 12],
+    'ArgMax': [1, 11, 12, 13],
+    'ArgMin': [1, 11, 12, 13],
     'ArrayFeatureExtractor': [],
     'Asin': [7],
     'Asinh': [9],

@@ -17,7 +17,7 @@
     'BatchNormalization': [1, 6, 7, 9],
     'Binarizer': [],
     'BitShift': [11],
-    'Cast': [1, 6, 9],
+    'Cast': [1, 6, 9, 13],
     'CastMap': [],
     'CategoryMapper': [],
     'Ceil': [1, 6],

@@ -189,8 +189,8 @@
 }

 backend_partial_support = {
-    'Cast': 'Cast string to float32/float64/int32/int64 are not supported in '
-            'Tensorflow.',
+    'Cast': 'Cast string to data types other than float32/float64/int32/int64 '
+            'is not supported in Tensorflow',
     'Clip': 'Clip input in uint64 is not supported in Tensorflow.',
     'ConcatFromSequence': 'new_axis=1 not supported in Tensorflow.',
     'ConvTranspose': 'ConvTranspose with dilations != 1, or transposed '
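Since these tables drive onnx-tf's reported coverage, the change can be checked directly. A quick sketch (the name backend_opset_version for the first dictionary is an assumption; the diff does not show its declaration):

from onnx_tf.opset_version import backend_opset_version, backend_partial_support

print(backend_opset_version['Cast'])    # [1, 6, 9, 13]
print(backend_partial_support['Cast'])  # the reworded partial-support note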

test/backend/test_model.py

Lines changed: 36 additions & 0 deletions
@@ -183,6 +183,42 @@ def test_relu_node_inplace(self):
     output = tf_rep.run({"X": X})
     np.testing.assert_almost_equal(output.X1, Y_ref)

+  def test_argmax_node_bfloat(self):
+    X = np.random.randn(2, 8).astype(np.float32)
+    Y_ref = np.argmax(X, axis=0)
+
+    graph_def = helper.make_graph(
+        [
+            helper.make_node("Cast", ["X0"], ["X1"], to=TensorProto.BFLOAT16),
+            helper.make_node("ArgMax", ["X1"], ["X2"], axis=0, keepdims=0, select_last_index=0)
+        ],
+        name="test",
+        inputs=[helper.make_tensor_value_info("X0", TensorProto.FLOAT, [2, 8])],
+        outputs=[
+            helper.make_tensor_value_info("X2", TensorProto.BFLOAT16, [2, 8])
+        ])
+    tf_rep = prepare(helper.make_model(graph_def))
+    output = tf_rep.run({"X0": X})
+    np.testing.assert_almost_equal(output.X2, Y_ref)
+
+  def test_argmin_node_bfloat(self):
+    X = np.random.randn(2, 8).astype(np.float32)
+    Y_ref = np.argmin(X, axis=0)
+
+    graph_def = helper.make_graph(
+        [
+            helper.make_node("Cast", ["X0"], ["X1"], to=TensorProto.BFLOAT16),
+            helper.make_node("ArgMin", ["X1"], ["X2"], axis=0, keepdims=0, select_last_index=0)
+        ],
+        name="test",
+        inputs=[helper.make_tensor_value_info("X0", TensorProto.FLOAT, [2, 8])],
+        outputs=[
+            helper.make_tensor_value_info("X2", TensorProto.BFLOAT16, [2, 8])
+        ])
+    tf_rep = prepare(helper.make_model(graph_def))
+    output = tf_rep.run({"X0": X})
+    np.testing.assert_almost_equal(output.X2, Y_ref)
+
   def test_initializer(self):
     if legacy_onnx_pre_ver(1, 2):
       raise unittest.SkipTest(

test/backend/test_node.py

Lines changed: 34 additions & 8 deletions
@@ -3,6 +3,7 @@
 from __future__ import print_function
 from __future__ import unicode_literals

+import sys
 import math
 import unittest
 import numpy as np

@@ -114,6 +115,7 @@ def test_arg_max(self):
         result = data.shape[axis] - result - 1
       np.testing.assert_almost_equal(output["reduced"], result)

+
   def test_arg_min(self):
     for axis in [0, 1]:
       node_def = helper.make_node("ArgMin", ["data"], ["reduced"],

@@ -147,6 +149,7 @@ def test_arg_min(self):
         result = data.shape[axis] - result - 1
       np.testing.assert_almost_equal(output["reduced"], result)

+
   def test_asinh(self):
     if legacy_opset_pre_ver(9):
       raise unittest.SkipTest("ONNX version {} doesn't support Asinh.".format(

@@ -196,32 +199,45 @@ def test_batch_normalization(self):

   def test_cast(self):
     if legacy_onnx_pre_ver(1, 2) or legacy_opset_pre_ver(6):
-      test_cases = [("FLOAT", tf.float32), ("UINT8", tf.uint8),
+      test_cases = [("FLOAT", tf.float32),
+                    ("UINT8", tf.uint8),
                     ("INT8", tf.int8),
-                    ("UINT16", tf.uint16), ("INT16", tf.int16),
-                    ("INT32", tf.int32), ("INT64", tf.int64), ("BOOL", tf.bool),
-                    ("FLOAT16", tf.float16), ("DOUBLE", tf.float64),
-                    ("COMPLEX64", tf.complex64), ("COMPLEX128", tf.complex128)]
+                    ("UINT16", tf.uint16),
+                    ("INT16", tf.int16),
+                    ("INT32", tf.int32),
+                    ("INT64", tf.int64),
+                    ("BOOL", tf.bool),
+                    ("FLOAT16", tf.float16),
+                    ("DOUBLE", tf.float64),
+                    ("COMPLEX64", tf.complex64),
+                    ("COMPLEX128", tf.complex128)]
     else:
       test_cases = [(TensorProto.FLOAT, tf.float32),
-                    (TensorProto.UINT8, tf.uint8), (TensorProto.INT8, tf.int8),
+                    (TensorProto.UINT8, tf.uint8),
+                    (TensorProto.INT8, tf.int8),
                     (TensorProto.UINT16, tf.uint16),
                     (TensorProto.INT16, tf.int16),
                     (TensorProto.INT32, tf.int32),
-                    (TensorProto.INT64, tf.int64), (TensorProto.BOOL, tf.bool),
+                    (TensorProto.INT64, tf.int64),
+                    (TensorProto.BOOL, tf.bool),
                     (TensorProto.FLOAT16, tf.float16),
                     (TensorProto.DOUBLE, tf.float64),
                     (TensorProto.COMPLEX64, tf.complex64),
                     (TensorProto.COMPLEX128, tf.complex128)]
     if not legacy_opset_pre_ver(9):
       test_cases.append((TensorProto.STRING, tf.string))
+    # added casting to bfloat16 from number in opset 13
+    if not legacy_opset_pre_ver(13):
+      test_cases.append((TensorProto.BFLOAT16, tf.bfloat16))
     for ty, tf_type in test_cases:
       node_def = helper.make_node("Cast", ["input"], ["output"], to=ty)
       vector = [2, 3]
       output = run_node(node_def, [vector])
       np.testing.assert_equal(output["output"].dtype, tf_type)
-
     if not legacy_opset_pre_ver(9):
+      # test_cases2 is focused on Strings to Numbers
+      # Note: casting from string to bfloat16 is not allowed by tf.strings.to_number
+      # so no BFLOAT16 in test_cases2.
       test_cases2 = [(TensorProto.FLOAT, tf.float32),
                      (TensorProto.INT32, tf.int32),
                      (TensorProto.INT64, tf.int64),

@@ -232,6 +248,16 @@ def test_cast(self):
       output = run_node(node_def, [vector])
       np.testing.assert_equal(output["output"].dtype, tf_type)

+    if not legacy_opset_pre_ver(9):
+      # test_case3 is focused on Strings to float and the special floating-point values.
+      test_cases3 = [(TensorProto.FLOAT, tf.float32),
+                     (TensorProto.DOUBLE, tf.float64)]
+      for ty, tf_type in test_cases3:
+        node_def = helper.make_node("Cast", ["input"], ["output"], to=ty)
+        vector = ['3.14159', '1e-5', '1E8', 'NaN', '-INF', '+INF']
+        output = run_node(node_def, [vector])
+        np.testing.assert_equal(output["output"].dtype, tf_type)
+
   def test_ceil(self):
     node_def = helper.make_node("Ceil", ["X"], ["Y"])
     x = self._get_rnd_float32(shape=[1000])
