Skip to content

Commit 86d0f1f

Browse files
committed
formatting with yapf and pylint
Signed-off-by: Dom Miketa <[email protected]>
1 parent 594f9ae commit 86d0f1f

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

onnx_tf/handlers/backend/depth_to_space.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import tensorflow as tf
44

55
from onnx_tf.common import get_data_format
6+
from onnx_tf.common.tf_helper import tf_shape
67
from onnx_tf.handlers.backend_handler import BackendHandler
78
from onnx_tf.handlers.handler import onnx_op
89
from onnx_tf.handlers.handler import tf_func
@@ -24,24 +25,27 @@ def version_1(cls, node, **kwargs):
2425
attrs = copy.deepcopy(node.attrs)
2526
attrs["data_format"] = storage_format
2627
return [
27-
cls.make_tensor_from_onnx_node(
28-
node, attrs=attrs, c_first_cuda_only=True, **kwargs)
28+
cls.make_tensor_from_onnx_node(node,
29+
attrs=attrs,
30+
c_first_cuda_only=True,
31+
**kwargs)
2932
]
3033

3134
@classmethod
3235
def _common(cls, node, **kwargs):
3336
x = kwargs["tensor_dict"][node.inputs[0]]
3437
x_rank = len(x.get_shape())
35-
storage_format, compute_format = get_data_format(x_rank)
38+
storage_format, _ = get_data_format(x_rank)
3639
attrs = copy.deepcopy(node.attrs)
3740
attrs["data_format"] = storage_format
3841
mode = attrs.get("mode", "DCR")
3942

4043
if mode == "CRD":
4144
# need native computation
4245
bsize = attrs.get("blocksize")
43-
x_shape = tf.shape(x)
44-
batch, channel, height, width = x_shape[0], x_shape[1], x_shape[2], x_shape[3]
46+
x_shape = tf_shape(x)
47+
batch, channel = x_shape[0], x_shape[1]
48+
height, width = x_shape[2], x_shape[3]
4549
csize = channel // (bsize**2)
4650

4751
reshape_node = tf.reshape(x, [batch, csize, bsize, bsize, height, width])
@@ -51,11 +55,12 @@ def _common(cls, node, **kwargs):
5155
[batch, csize, height * bsize, width * bsize])
5256
]
5357

54-
else:
55-
return [
56-
cls.make_tensor_from_onnx_node(
57-
node, attrs=attrs, c_first_cuda_only=True, **kwargs)
58-
]
58+
return [
59+
cls.make_tensor_from_onnx_node(node,
60+
attrs=attrs,
61+
c_first_cuda_only=True,
62+
**kwargs)
63+
]
5964

6065
@classmethod
6166
def version_11(cls, node, **kwargs):

0 commit comments

Comments
 (0)