Skip to content

Commit d622040

Browse files
authored
Pool fix (onnx#10)
* Misc CNN fixes and VGG tests pass * script to download all models * fix node test * all models in the script * relax batch_norm test * wip * more fixes * fix tests * fix model test * fix pooling * only test shufflenet for classification
1 parent a5990b3 commit d622040

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

onnx_tf/backend.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class TensorflowBackend(Backend):
207207

208208
@classmethod
209209
def guess_tf_pad(cls, pads):
210-
tf_pad = "VALID" if pads == None or pads[-1] == 0 else "SAME"
210+
tf_pad = "VALID" if pads == None or pads[-1] == 0 or (pads[0] != pads[2]) else "SAME"
211211
warnings.warn("Unsupported pads attribute by Tensorflow in "
212212
"pool operator. Your padding is {}, we guess "
213213
"you want {} padding.".format(str(pads), tf_pad),
@@ -425,7 +425,7 @@ def handle_arg_min(cls, node, input_dict):
425425
return [tf.argmin(data, axis=axis)]
426426

427427
@classmethod
428-
def _pool(cls, node, input_dict, pool_func):
428+
def _pool(cls, node, input_dict, pool_func, guess_or_manual_pad):
429429
x = input_dict[node.inputs[0]]
430430
x_rank = len(x.get_shape())
431431

@@ -435,22 +435,38 @@ def _pool(cls, node, input_dict, pool_func):
435435
kernel_shape = node.attrs["kernel_shape"]
436436
strides = node.attrs["strides"]
437437

438+
# By default, do not pad
439+
pad = "VALID"
438440
if "pads" in node.attrs.keys():
439-
x = cls.get_padding_as_op(x, node.attrs["pads"])
441+
if (guess_or_manual_pad == 0):
442+
pad = cls.guess_tf_pad(node.attrs["pads"])
443+
else:
444+
x = cls.get_padding_as_op(x, node.attrs["pads"])
445+
pad = "VALID"
440446

441447
if support_cuda:
442-
pooled = pool_func(x, [1, 1] + kernel_shape, [1, 1] + strides, "VALID",
448+
pooled = pool_func(x, [1, 1] + kernel_shape, [1, 1] + strides, pad,
443449
data_format=data_format)
444450
else:
445451
x = tf.transpose(x, perm=[0, 2, 3, 1])
446-
pooled = pool_func(x, [1] + kernel_shape + [1], [1] + strides + [1], "VALID",
452+
pooled = pool_func(x, [1] + kernel_shape + [1], [1] + strides + [1], pad,
447453
data_format=data_format)
448454
pooled = tf.transpose(pooled, perm=[0, 3, 1, 2])
449455
return [pooled]
450456

451457
@classmethod
452458
def handle_average_pool(cls, node, input_dict):
453-
return cls._pool(node, input_dict, tf.nn.avg_pool)
459+
spatial_dim = list(input_dict[node.inputs[0]].get_shape()[2:])
460+
kernel_shape = node.attrs.get("kernel_shape", [])
461+
global_pool = True
462+
for i in range(len(spatial_dim)):
463+
global_pool = global_pool and (spatial_dim[i] < kernel_shape[i])
464+
465+
if global_pool:
466+
return cls.handle_global_average_pool(node, input_dict)
467+
468+
# 0 = guess padding
469+
return cls._pool(node, input_dict, tf.nn.avg_pool, 0)
454470

455471
@classmethod
456472
def handle_batch_normalization(cls, node, input_dict):
@@ -703,7 +719,8 @@ def handle_max(cls, node, input_dict):
703719

704720
@classmethod
705721
def handle_max_pool(cls, node, input_dict):
706-
return cls._pool(node, input_dict, tf.nn.max_pool)
722+
# 1 = pad manually
723+
return cls._pool(node, input_dict, tf.nn.max_pool, 1)
707724

708725
@classmethod
709726
def handle_min(cls, node, input_dict):

test/download_model.sh

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
mkdir -p ../../onnx_models/
2+
13
wget https://s3.amazonaws.com/download.onnx/models/bvlc_alexnet.tar.gz --directory-prefix=../../onnx_models/
24
pushd ../../onnx_models/ && tar -xzf bvlc_alexnet.tar.gz && popd
35

test/test_model_large.py

-2
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,9 @@ def test_resnet50(self):
5050
_test_nn("resnet50", "gpu_0/softmax_1")
5151

5252
def test_inception_v1(self):
53-
return
5453
_test_nn("inception_v1", "prob_1")
5554

5655
def test_inception_v2(self):
57-
return
5856
_test_nn("inception_v2", "prob_1")
5957

6058
if __name__ == '__main__':

0 commit comments

Comments
 (0)