Skip to content

Commit c57bff5

Browse files
Add data cast to fix LessOrEqual issue with tf 2.4 (onnx#839)
Add data cast to fix LessOrEqual issue with tf 2.4, since tf LessEqual doesn't support uint16, uint32, uint64. Also add test code to verify unsupported data types. Signed-off-by: Chin Huang <[email protected]> Co-authored-by: Winnie Tsang <[email protected]>
1 parent a5c4d64 commit c57bff5

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

onnx_tf/handlers/backend/less_or_equal.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,48 @@
44
from onnx_tf.handlers.handler import onnx_op
55
from onnx_tf.handlers.handler import tf_func
66
from .control_flow_mixin import ComparisonMixin
7+
from onnx_tf.common import sys_config
8+
from onnx_tf.common import exception
9+
import onnx_tf.common.data_type as data_type
710

811

912
@onnx_op("LessOrEqual")
1013
@tf_func(tf.less_equal)
1114
class LessOrEqual(ComparisonMixin, BackendHandler):
15+
cast_map = {tf.uint16: tf.int32, tf.uint32: tf.int64}
16+
supported_types = [
17+
tf.uint8, tf.int8, tf.int16, tf.int32, tf.int64, tf.float16,
18+
tf.float32, tf.float64, tf.bfloat16
19+
]
20+
21+
@classmethod
22+
def args_check(cls, node, **kwargs):
23+
# update cast map based on the auto_cast config option
24+
cls.cast_map[tf.uint64] = tf.int64 if sys_config.auto_cast else None
25+
26+
x = kwargs["tensor_dict"][node.inputs[0]]
27+
y = kwargs["tensor_dict"][node.inputs[1]]
28+
29+
# throw an error if the data type is not natively supported by
30+
# Tensorflow, cannot be safely cast, and auto_cast option is False
31+
if x.dtype in cls.cast_map and cls.cast_map[x.dtype] is None:
32+
exception.DTYPE_NOT_CAST_EXCEPT(
33+
"LessOrEqual input " + node.inputs[0] + " with data type '" +
34+
data_type.tf_to_np_str(x.dtype) + "'",
35+
data_type.tf_to_np_str_list(cls.supported_types))
36+
if y.dtype in cls.cast_map and cls.cast_map[y.dtype] is None:
37+
exception.DTYPE_NOT_CAST_EXCEPT(
38+
"LessOrEqual input " + node.inputs[1] + " with data type '" +
39+
data_type.tf_to_np_str(y.dtype) + "'",
40+
data_type.tf_to_np_str_list(cls.supported_types))
1241

1342
@classmethod
1443
def version_12(cls, node, **kwargs):
15-
return [cls.make_tensor_from_onnx_node(node, **kwargs)]
44+
def dtype_cast(x):
45+
return tf.cast(x, cls.cast_map[x.dtype]) if x.dtype in cls.cast_map else x
46+
47+
# handle data types that are not natively supported by Tensorflow
48+
x = dtype_cast(kwargs["tensor_dict"][node.inputs[0]])
49+
y = dtype_cast(kwargs["tensor_dict"][node.inputs[1]])
50+
51+
return [cls.make_tensor_from_onnx_node(node, inputs=[x, y])]

test/backend/test_node.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,11 +1302,11 @@ def test_less_or_equal(self):
13021302
node_def = helper.make_node('LessOrEqual', ['X', 'Y'], ['Z'])
13031303
shape = [2, 3, 4, 5]
13041304
x = self._get_rnd_int(
1305-
np.iinfo(np.uint64).min,
1306-
np.iinfo(np.uint64).max, shape, np.uint64)
1305+
np.iinfo(np.int64).min,
1306+
np.iinfo(np.int64).max, shape, np.int64)
13071307
y = self._get_rnd_int(
1308-
np.iinfo(np.uint64).min,
1309-
np.iinfo(np.uint64).max, shape, np.uint64)
1308+
np.iinfo(np.int64).min,
1309+
np.iinfo(np.int64).max, shape, np.int64)
13101310
output = run_node(node_def, [x, y])
13111311
np.testing.assert_equal(output['Z'], np.less_equal(x, y))
13121312
# test with broadcast
@@ -1319,6 +1319,22 @@ def test_less_or_equal(self):
13191319
np.finfo(np.float16).max, shape2).astype(np.float16)
13201320
output = run_node(node_def, [x, y])
13211321
np.testing.assert_equal(output['Z'], np.less_equal(x, y))
1322+
# test data types that are not natively supported by Tensorflow
1323+
x = self._get_rnd_int(
1324+
np.iinfo(np.uint32).min,
1325+
np.iinfo(np.uint32).max, shape, np.uint32)
1326+
y = self._get_rnd_int(
1327+
np.iinfo(np.uint32).min,
1328+
np.iinfo(np.uint32).max, shape, np.uint32)
1329+
output = run_node(node_def, [x, y])
1330+
np.testing.assert_equal(output['Z'], np.less_equal(x, y))
1331+
x = self._get_rnd_int(
1332+
np.iinfo(np.uint64).min,
1333+
np.iinfo(np.uint64).max, shape, np.uint64)
1334+
y = self._get_rnd_int(
1335+
np.iinfo(np.uint64).min,
1336+
np.iinfo(np.uint64).max, shape, np.uint64)
1337+
self.assertRaises(RuntimeError, run_node, node_def, [x, y])
13221338

13231339
def test_lp_normalization(self):
13241340
for ordr in range(1, 3):
@@ -1627,12 +1643,12 @@ def test_matmul(self):
16271643
a = self._get_rnd_float32(shape=[5, 6])
16281644
b = self._get_rnd_float32(shape=[6, 5])
16291645
output = run_node(node_def, [a, b])
1630-
np.testing.assert_almost_equal(output["Y"], np.matmul(a, b))
1646+
np.testing.assert_allclose(output["Y"], np.matmul(a, b), rtol=1e-6, atol=1e-6)
16311647
# test data types that are not natively supported by Tensorflow
16321648
a = self._get_rnd_int(0, 1000, [10, 10], np.uint32)
16331649
b = self._get_rnd_int(0, 1000, [10, 10], np.uint32)
16341650
output = run_node(node_def, [a, b])
1635-
np.testing.assert_almost_equal(output["Y"], np.matmul(a, b))
1651+
np.testing.assert_allclose(output["Y"], np.matmul(a, b), rtol=1e-6, atol=1e-6)
16361652
# sys_config.auto_cast=False and a or b dtype=uint64 should throw exception
16371653
self.assertRaises(
16381654
RuntimeError, run_node, node_def,

0 commit comments

Comments
 (0)