Commit c6784d4

Backport dynamic shape support for ConvTranspose and BatchNormalization (#724)
1. Backport dynamic shape support for ConvTranspose and BatchNormalization to the tf-1.x branch.
2. Backport ConvTranspose test cases to the tf-1.x branch.
1 parent 425cece commit c6784d4

File tree

  onnx_tf/handlers/backend/batch_normalization.py
  onnx_tf/handlers/backend/conv_mixin.py
  onnx_tf/handlers/backend/upsample.py
  test/backend/test_dynamic_shape.py
  test/backend/test_node.py

5 files changed: +182 −21 lines
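Throughout the backport, the pattern for supporting unknown dimensions is the same: when the static shape reports `None`, fall back to the runtime value from `tf.shape` and stack the result into a shape tensor. A minimal standalone sketch of that idea (the helper name here is illustrative, not part of the handlers):

```python
import tensorflow as tf


def resolve_dim(x, shape_list, dim):
  # Substitute the runtime size for a statically unknown dimension and turn
  # the mixed list of Python ints and scalar tensors into a shape tensor.
  if shape_list[dim] is None:
    shape_list[dim] = tf.shape(x)[dim]
    return tf.stack(shape_list)
  return shape_list
```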

onnx_tf/handlers/backend/batch_normalization.py

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,10 @@ def _common(cls, node, **kwargs):

    params_shape_broadcast = list([1, x_shape[1]] +
                                  [1 for _ in range(2, x_rank)])
+    # process unknown channel shape
+    if params_shape_broadcast[1] is None:
+      params_shape_broadcast[1] = tf.shape(x)[1]
+      params_shape_broadcast = tf.stack(params_shape_broadcast)

    total_num_dim = len(x.get_shape())
    scale = tf.reshape(tensor_dict[node.inputs[1]], params_shape_broadcast)
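The handler reshapes the per-channel `scale`, `bias`, `mean` and `var` inputs to `[1, C, 1, ..., 1]` so they broadcast against an `NC...` input; the new branch lets that reshape work when `C` is only known at runtime. A condensed sketch of the idea, assuming an `NC...` layout and a known rank (illustrative, not the handler itself):

```python
import tensorflow as tf


def broadcast_param(x, param):
  # param has shape [C]; build [1, C, 1, ..., 1], taking C from tf.shape(x)
  # when the static channel dimension is unknown.
  rank = x.get_shape().ndims
  c = x.get_shape().as_list()[1]
  shape = [1, c if c is not None else tf.shape(x)[1]] + [1] * (rank - 2)
  return tf.reshape(param, tf.stack(shape))
```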

onnx_tf/handlers/backend/conv_mixin.py

Lines changed: 22 additions & 0 deletions
@@ -127,6 +127,16 @@ def conv(cls, node, input_dict, transpose=False):
          ]
          conv_output_shape.insert(compute_c_idx, weights_shape[-2])

+          def handle_dynamic_batch_size(output_shape, batch_idx):
+            output_shape[batch_idx] = tf.shape(x)[batch_idx]
+            return tf.stack(output_shape)
+
+          # process dynamic batch size
+          if conv_output_shape[storage_format.find("N")] is None:
+            batch_idx = storage_format.find("N")
+            conv_output_shape = handle_dynamic_batch_size(conv_output_shape,
+                                                          batch_idx)
+
          # make strides to match input rank
          strides_full = [1] + strides
          strides_full.insert(compute_c_idx, 1)

@@ -169,6 +179,12 @@ def conv(cls, node, input_dict, transpose=False):
              pads[spatial_format.find(d) + spatial_size]
              for d, s in zip(compute_format, conv_rs_shape)
          ]
+
+          # process dynamic batch size
+          if size[compute_format.find("N")] is None:
+            batch_idx = compute_format.find("N")
+            size = handle_dynamic_batch_size(size, batch_idx)
+
          conv_rs = tf.slice(conv_rs, begin=begin, size=size)

          convolved.append(conv_rs)

@@ -190,6 +206,12 @@ def conv(cls, node, input_dict, transpose=False):
          ]
          conv_output_shape.insert(compute_c_idx, weights_shape[-2])

+          # process dynamic batch size
+          if conv_output_shape[storage_format.find("N")] is None:
+            batch_idx = storage_format.find("N")
+            conv_output_shape = handle_dynamic_batch_size(conv_output_shape,
+                                                          batch_idx)
+
          # make strides to match input rank
          strides_full = [1] + strides
          strides_full.insert(compute_c_idx, 1)
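`tf.nn.conv*d_transpose` needs an explicit `output_shape`, which the handler builds from the static input shape; when the batch entry is `None`, `handle_dynamic_batch_size` swaps in `tf.shape(x)[batch_idx]` and stacks the list into a tensor. The key point is that `output_shape` may be a 1-D tensor, so the batch size can be resolved at run time. A small standalone illustration of that point (NHWC, stride 1, shapes chosen only for the example; the handler itself works in NCHW/NCW as above):

```python
import tensorflow as tf

x = tf.random.uniform([2, 4, 6, 3])        # input; pretend the batch is dynamic
kernel = tf.random.uniform([2, 2, 5, 3])   # [kh, kw, out_channels, in_channels]
# batch taken from tf.shape at run time; other entries computed statically
out_shape = tf.stack([tf.shape(x)[0], 5, 7, 5])
y = tf.nn.conv2d_transpose(x, kernel, out_shape, strides=[1, 1, 1, 1],
                           padding="VALID")  # result shape: [2, 5, 7, 5]
```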

onnx_tf/handlers/backend/upsample.py

Lines changed: 4 additions & 2 deletions
@@ -9,6 +9,7 @@
from onnx_tf.handlers.handler import partial_support
from onnx_tf.handlers.handler import ps_description
from onnx_tf.handlers.handler import tf_func
+from onnx_tf.common.tf_helper import tf_shape


@onnx_op("Upsample")

@@ -54,7 +55,7 @@ def version_7(cls, node, **kwargs):
  @classmethod
  def version_9(cls, node, **kwargs):
    x = kwargs["tensor_dict"][node.inputs[0]]
-    x_shape = x.get_shape().as_list()
+    x_shape = tf_shape(x)
    attrs = copy.deepcopy(node.attrs)
    scales = kwargs["tensor_dict"][node.inputs[1]]

@@ -65,7 +66,8 @@ def version_9(cls, node, **kwargs):
    with tf.control_dependencies([assert_n_c_scale_is_one]):
      h_w_scale = scales[2:]
      h_w_shape = x_shape[2:]
-      new_h_w_shape = tf.cast(h_w_scale * h_w_shape, tf.int32)
+      new_h_w_shape = tf.cast(h_w_scale * tf.cast(h_w_shape, scales.dtype),
+                              tf.int32)

      mode = attrs.get("mode", "nearest")
      if mode.lower() == "bilinear" or mode.lower() == "linear":
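With `tf_shape`, `x_shape` can now be an integer tensor rather than a Python list, so the extra `tf.cast` is needed before multiplying by the float `scales` input; TensorFlow does not promote dtypes implicitly in that product. A quick illustration of the computation (standalone, not the handler code):

```python
import tensorflow as tf

x = tf.zeros([1, 3, 10, 20])
scales = tf.constant([1.0, 1.0, 2.0, 2.0])

h_w_shape = tf.shape(x)[2:]                 # int32 tensor: [10, 20]
new_h_w_shape = tf.cast(scales[2:] * tf.cast(h_w_shape, scales.dtype),
                        tf.int32)           # [20, 40]
```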

test/backend/test_dynamic_shape.py

Lines changed: 95 additions & 1 deletion
@@ -23,7 +23,7 @@ class TestDynamicShape(unittest.TestCase):

  def _get_rnd_float32(self, low=-1.0, high=1.0, shape=None):
    output = np.random.uniform(low, high, shape)
-    if shape == None:
+    if shape is None:
      return np.float32(output)
    else:
      return output.astype(np.float32)

@@ -87,6 +87,100 @@ def test_arg_min(self):
    expected_output = x.shape[axis] - expected_output - 1
    np.testing.assert_almost_equal(output['Y'], expected_output)

+  def _batch_normalization(self, x, mean, variance, bias, scale,
+                           variance_epsilon):
+    inv = np.reciprocal(np.sqrt(variance + variance_epsilon))
+    if scale is not None:
+      inv *= scale
+    return x * inv + (bias - mean * inv if bias is not None else -mean * inv)
+
+  def test_batch_normalization(self):
+    if legacy_opset_pre_ver(6):
+      raise unittest.SkipTest("Backend doesn't support consumed flag")
+    node_def = helper.make_node("BatchNormalization",
+                                ["X", "scale", "bias", "mean", "var"], ["Y"],
+                                epsilon=0.001)
+    graph_def = helper.make_graph(
+        [node_def],
+        name="test_unknown_shape",
+        inputs=[
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, None, None, None]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [None]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [None]),
+            helper.make_tensor_value_info("mean", TensorProto.FLOAT, [None]),
+            helper.make_tensor_value_info("var", TensorProto.FLOAT, [None])
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None, None])
+        ])
+    x_shape = [3, 5, 4, 2]
+    param_shape = [5]
+    _param_shape = [1, 5, 1, 1]
+    x = self._get_rnd_float32(0, 1, shape=x_shape)
+    m = self._get_rnd_float32(0, 1, shape=param_shape)
+    _m = m.reshape(_param_shape)
+    v = self._get_rnd_float32(0, 1, shape=param_shape)
+    _v = v.reshape(_param_shape)
+    scale = self._get_rnd_float32(0, 1, shape=param_shape)
+    _scale = scale.reshape(_param_shape)
+    bias = self._get_rnd_float32(0, 1, shape=param_shape)
+    _bias = bias.reshape(_param_shape)
+    golden = self._batch_normalization(x, _m, _v, _bias, _scale, 0.001)
+    tf_rep = onnx_graph_to_tensorflow_rep(graph_def)
+    output = tf_rep.run({"X": x, "scale": scale, "bias": bias, "mean": m, "var": v})
+    np.testing.assert_almost_equal(output["Y"], golden, decimal=5)
+
+  def test_conv_transpose(self):
+    # test dynamic batch size on transpose of 2d convolution
+    pads = [1, 1, 1, 1]
+    x_shape = [1, 3, 4, 6]
+    x = self._get_rnd_float32(shape=x_shape)
+    weight_shape = [3, 5, 2, 2]
+    weights = self._get_rnd_float32(shape=weight_shape)
+
+    node_def = helper.make_node("ConvTranspose", ["X", "weights"], ["Y"],
+                                pads=pads)
+    graph_def = helper.make_graph(
+        [node_def],
+        name="test_unknown_shape",
+        inputs=[
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, 3, 4, 6]),
+            helper.make_tensor_value_info("weights", TensorProto.FLOAT, weight_shape)
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None, None])
+        ])
+
+    tf_rep = onnx_graph_to_tensorflow_rep(graph_def)
+    output = tf_rep.run({"X": x, "weights": weights})
+
+    padh_left = weight_shape[2] - 1 - pads[0]
+    padh_right = weight_shape[2] - 1 - pads[1]
+    padw_left = weight_shape[3] - 1 - pads[2]
+    padw_right = weight_shape[3] - 1 - pads[3]
+
+    kh = weight_shape[2]
+    kw = weight_shape[3]
+    outh = x_shape[2] + padh_right + padh_right - (kh - 1)
+    outw = x_shape[3] + padw_right + padw_right - (kw - 1)
+
+    out_shape = [x_shape[0], weight_shape[1], outh, outw]
+
+    test_output = np.zeros(out_shape)
+    for b in range(0, x_shape[0]):
+      for m in range(0, weight_shape[1]):
+        for c in range(0, x_shape[1]):
+          for h in range(0, outh):
+            for w in range(0, outw):
+              for k1 in range(h, h + kh):
+                for k2 in range(w, w + kw):
+                  if (k1 - padh_left >= 0 and k2 - padw_left >= 0):
+                    test_output[b][m][h][w] += x[b][c][k1 - padh_left][
+                        k2 - padw_left] * weights[c][m][kh + h - 1 -
+                                                        k1][kw + w - 1 - k2]
+
+    np.testing.assert_almost_equal(output["Y"], test_output, decimal=5)
+
  def test_slice(self):
    # test case 1 with normal inputs
    axes = [0, 1, 2]
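For readers of the nested reference loop above: with stride 1 and a single group it is a direct evaluation of ConvTranspose, which can be read as

    Y[b, m, h, w] = \sum_{c} \sum_{i, j} X[b, c, i, j] \cdot W[c, m, h + p_h - i, w + p_w - j]

where p_h and p_w are the leading (beginning) pads of the two spatial axes and any term whose input or kernel index falls out of range is dropped; the loop variables k1 and k2 correspond to i + (kh - 1 - p_h) and j + (kw - 1 - p_w).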

test/backend/test_node.py

Lines changed: 57 additions & 18 deletions
@@ -23,7 +23,7 @@ class TestNode(unittest.TestCase):

  def _get_rnd_float32(self, low=-1.0, high=1.0, shape=None):
    output = np.random.uniform(low, high, shape)
-    if shape == None:
+    if shape is None:
      return np.float32(output)
    else:
      return output.astype(np.float32)

@@ -460,29 +460,68 @@ def test_conv_integer(self):
    np.testing.assert_almost_equal(output["Y"], y)

  def test_conv_transpose(self):
-    # Fix test in the future.
-    return
-    device = "CUDA"
-    if not supports_device(device):
-      raise unittest.SkipTest(
-          "Backend doesn't support device {}".format(device))
+    device = "CUDA" if supports_device("CUDA") else "CPU"
+
+    pads = [1, 1]
    node_def = helper.make_node("ConvTranspose", ["X", "weights"], ["Y"],
-                                pads=[1, 1])
-    x_shape = [1, 5, 4]
+                                pads=pads)
+    x_shape = [1, 3, 4]
    x = self._get_rnd_float32(shape=x_shape)
-    weight_shape = [5, 3, 2]
+    weight_shape = [3, 5, 2]
    weights = self._get_rnd_float32(shape=weight_shape)
    output = run_node(node_def, [x, weights], device=device)
-    out_shape = [x_shape[0], weight_shape[1], x_shape[2]]
+
+    padh_left = weight_shape[2] - 1 - pads[0]
+    padh_right = weight_shape[2] - 1 - pads[1]
+    kh = weight_shape[2]
+    outh = x_shape[2] + padh_right + padh_right - (kh - 1)
+
+    out_shape = [x_shape[0], weight_shape[1], outh]
+
    test_output = np.zeros(out_shape)
    for b in range(0, x_shape[0]):
      for m in range(0, weight_shape[1]):
-        for h in range(0, x_shape[2]):
-          v = 0
-          for c in range(0, x_shape[1]):
-            for k in range(h, min(h + weight_shape[2], x_shape[2])):
-              v += x[b][c][k] * weights[c][m][k - h]
-          test_output[b][m][h] = v
+        for c in range(0, x_shape[1]):
+          for h in range(0, outh):
+            for k in range(h, h + kh):
+              if (k - padh_left >= 0):
+                test_output[b][m][h] += x[b][c][k - padh_left] * weights[c][m][kh + h - 1 - k]
+
+    np.testing.assert_almost_equal(output["Y"], test_output, decimal=5)
+
+    # test when the spatial dimension of the convolution is 2
+    pads = [1, 1, 1, 1]
+    node_def = helper.make_node("ConvTranspose", ["X", "weights"], ["Y"],
+                                pads=pads)
+    x_shape = [1, 3, 4, 6]
+    x = self._get_rnd_float32(shape=x_shape)
+    weight_shape = [3, 5, 2, 2]
+    weights = self._get_rnd_float32(shape=weight_shape)
+    output = run_node(node_def, [x, weights], device=device)
+
+    padh_left = weight_shape[2] - 1 - pads[0]
+    padh_right = weight_shape[2] - 1 - pads[1]
+    padw_left = weight_shape[3] - 1 - pads[2]
+    padw_right = weight_shape[3] - 1 - pads[3]
+
+    kh = weight_shape[2]
+    kw = weight_shape[3]
+    outh = x_shape[2] + padh_right + padh_right - (kh - 1)
+    outw = x_shape[3] + padw_right + padw_right - (kw - 1)
+
+    out_shape = [x_shape[0], weight_shape[1], outh, outw]
+
+    test_output = np.zeros(out_shape)
+    for b in range(0, x_shape[0]):
+      for m in range(0, weight_shape[1]):
+        for c in range(0, x_shape[1]):
+          for h in range(0, outh):
+            for w in range(0, outw):
+              for k1 in range(h, h + kh):
+                for k2 in range(w, w + kw):
+                  if (k1 - padh_left >= 0 and k2 - padw_left >= 0):
+                    test_output[b][m][h][w] += x[b][c][k1 - padh_left][k2 - padw_left] * weights[c][m][kh + h - 1 - k1][kw + w - 1 - k2]
+
    np.testing.assert_almost_equal(output["Y"], test_output, decimal=5)

  def test_cosh(self):

@@ -1124,7 +1163,7 @@ def test_loop(self):
    x_in = helper.make_tensor_value_info('x', TensorProto.INT32, [None])
    y_in = helper.make_tensor_value_info('y', TensorProto.INT32, [None])

-    cond_out = helper.make_tensor_value_info('cond', TensorProto.STRING, [])
+    cond_out = helper.make_tensor_value_info('cond', TensorProto.BOOL, [])
    new_cond_out = helper.make_tensor_value_info('new_cond', TensorProto.BOOL,
                                                 [])
    sum1_out = helper.make_tensor_value_info('sum1', TensorProto.INT32, [None])
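The last hunk is a small correctness fix in the Loop test: in the ONNX spec the loop's `cond` value is a boolean tensor, so its value_info has to be declared as BOOL rather than STRING, e.g.:

```python
from onnx import helper, TensorProto

# Loop's termination condition is a scalar boolean tensor in ONNX.
cond_out = helper.make_tensor_value_info('cond', TensorProto.BOOL, [])
```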

0 commit comments
