3
3
import tensorflow as tf
4
4
5
5
from onnx_tf .common import get_data_format
6
+ from onnx_tf .common .tf_helper import tf_shape
6
7
from onnx_tf .handlers .backend_handler import BackendHandler
7
8
from onnx_tf .handlers .handler import onnx_op
8
9
from onnx_tf .handlers .handler import tf_func
@@ -24,24 +25,27 @@ def version_1(cls, node, **kwargs):
24
25
attrs = copy .deepcopy (node .attrs )
25
26
attrs ["data_format" ] = storage_format
26
27
return [
27
- cls .make_tensor_from_onnx_node (
28
- node , attrs = attrs , c_first_cuda_only = True , ** kwargs )
28
+ cls .make_tensor_from_onnx_node (node ,
29
+ attrs = attrs ,
30
+ c_first_cuda_only = True ,
31
+ ** kwargs )
29
32
]
30
33
31
34
@classmethod
32
35
def _common (cls , node , ** kwargs ):
33
36
x = kwargs ["tensor_dict" ][node .inputs [0 ]]
34
37
x_rank = len (x .get_shape ())
35
- storage_format , compute_format = get_data_format (x_rank )
38
+ storage_format , _ = get_data_format (x_rank )
36
39
attrs = copy .deepcopy (node .attrs )
37
40
attrs ["data_format" ] = storage_format
38
41
mode = attrs .get ("mode" , "DCR" )
39
42
40
43
if mode == "CRD" :
41
44
# need native computation
42
45
bsize = attrs .get ("blocksize" )
43
- x_shape = tf .shape (x )
44
- batch , channel , height , width = x_shape [0 ], x_shape [1 ], x_shape [2 ], x_shape [3 ]
46
+ x_shape = tf_shape (x )
47
+ batch , channel = x_shape [0 ], x_shape [1 ]
48
+ height , width = x_shape [2 ], x_shape [3 ]
45
49
csize = channel // (bsize ** 2 )
46
50
47
51
reshape_node = tf .reshape (x , [batch , csize , bsize , bsize , height , width ])
@@ -51,11 +55,12 @@ def _common(cls, node, **kwargs):
51
55
[batch , csize , height * bsize , width * bsize ])
52
56
]
53
57
54
- else :
55
- return [
56
- cls .make_tensor_from_onnx_node (
57
- node , attrs = attrs , c_first_cuda_only = True , ** kwargs )
58
- ]
58
+ return [
59
+ cls .make_tensor_from_onnx_node (node ,
60
+ attrs = attrs ,
61
+ c_first_cuda_only = True ,
62
+ ** kwargs )
63
+ ]
59
64
60
65
@classmethod
61
66
def version_11 (cls , node , ** kwargs ):
0 commit comments