Skip to content

Commit 6685f45

Browse files
authored
fix dynamic shape inference in DepthToSpace (#880)
* fix dynamic shape inference in DepthToSpace Signed-off-by: Dom Miketa <[email protected]> * add unittest Signed-off-by: Dom Miketa <[email protected]> * formatting with yapf and pylint Signed-off-by: Dom Miketa <[email protected]> * format unittest Signed-off-by: Dom Miketa <[email protected]>
1 parent 3c358ad commit 6685f45

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

onnx_tf/handlers/backend/depth_to_space.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import tensorflow as tf
44

55
from onnx_tf.common import get_data_format
6+
from onnx_tf.common.tf_helper import tf_shape
67
from onnx_tf.handlers.backend_handler import BackendHandler
78
from onnx_tf.handlers.handler import onnx_op
89
from onnx_tf.handlers.handler import tf_func
@@ -24,23 +25,27 @@ def version_1(cls, node, **kwargs):
2425
attrs = copy.deepcopy(node.attrs)
2526
attrs["data_format"] = storage_format
2627
return [
27-
cls.make_tensor_from_onnx_node(
28-
node, attrs=attrs, c_first_cuda_only=True, **kwargs)
28+
cls.make_tensor_from_onnx_node(node,
29+
attrs=attrs,
30+
c_first_cuda_only=True,
31+
**kwargs)
2932
]
3033

3134
@classmethod
3235
def _common(cls, node, **kwargs):
3336
x = kwargs["tensor_dict"][node.inputs[0]]
3437
x_rank = len(x.get_shape())
35-
storage_format, compute_format = get_data_format(x_rank)
38+
storage_format, _ = get_data_format(x_rank)
3639
attrs = copy.deepcopy(node.attrs)
3740
attrs["data_format"] = storage_format
3841
mode = attrs.get("mode", "DCR")
3942

4043
if mode == "CRD":
4144
# need native computation
4245
bsize = attrs.get("blocksize")
43-
batch, channel, height, width = x.shape
46+
x_shape = tf_shape(x)
47+
batch, channel = x_shape[0], x_shape[1]
48+
height, width = x_shape[2], x_shape[3]
4449
csize = channel // (bsize**2)
4550

4651
reshape_node = tf.reshape(x, [batch, csize, bsize, bsize, height, width])
@@ -50,11 +55,12 @@ def _common(cls, node, **kwargs):
5055
[batch, csize, height * bsize, width * bsize])
5156
]
5257

53-
else:
54-
return [
55-
cls.make_tensor_from_onnx_node(
56-
node, attrs=attrs, c_first_cuda_only=True, **kwargs)
57-
]
58+
return [
59+
cls.make_tensor_from_onnx_node(node,
60+
attrs=attrs,
61+
c_first_cuda_only=True,
62+
**kwargs)
63+
]
5864

5965
@classmethod
6066
def version_11(cls, node, **kwargs):

test/backend/test_dynamic_shape.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,37 @@ def test_conv_transpose(self):
248248

249249
np.testing.assert_almost_equal(tf_model_output[0], test_output, decimal=5)
250250

251+
def test_depth_to_space(self):
252+
b, c, h, w = shape = [2, 48, 5, 6]
253+
blocksize = 4
254+
x = self._get_rnd_float32(shape=shape)
255+
node_def = helper.make_node("DepthToSpace", ["X"], ["Y"],
256+
blocksize=blocksize,
257+
mode="DCR")
258+
graph_def = helper.make_graph(
259+
[node_def],
260+
name="test_unknown_shape",
261+
inputs=[
262+
helper.make_tensor_value_info("X", TensorProto.FLOAT,
263+
[None, None, None, None])
264+
],
265+
outputs=[
266+
helper.make_tensor_value_info("Y", TensorProto.FLOAT,
267+
[None, None, None, None])
268+
])
269+
tf_rep = onnx_graph_to_tensorflow_rep(graph_def)
270+
# export to tf.saved_model
271+
model_path = 'test_dynamic_shape/depth_to_space'
272+
tf_rep.export_graph(model_path)
273+
# load the saved_model back
274+
tf_model = tf.saved_model.load(model_path)
275+
# run the model
276+
tf_model_output = tf_model(X=x)
277+
tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w])
278+
tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2])
279+
y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize])
280+
np.testing.assert_almost_equal(tf_model_output[0], y)
281+
251282
def test_eye_like(self):
252283
if legacy_opset_pre_ver(9):
253284
raise unittest.SkipTest("ONNX version {} doesn't support EyeLike.".format(

0 commit comments

Comments
 (0)