@@ -207,7 +207,7 @@ class TensorflowBackend(Backend):
207
207
208
208
@classmethod
209
209
def guess_tf_pad (cls , pads ):
210
- tf_pad = "VALID" if pads == None or pads [- 1 ] == 0 else "SAME"
210
+ tf_pad = "VALID" if pads == None or pads [- 1 ] == 0 or ( pads [ 0 ] != pads [ 2 ]) else "SAME"
211
211
warnings .warn ("Unsupported pads attribute by Tensorflow in "
212
212
"pool operator. Your padding is {}, we guess "
213
213
"you want {} padding." .format (str (pads ), tf_pad ),
@@ -425,7 +425,7 @@ def handle_arg_min(cls, node, input_dict):
425
425
return [tf .argmin (data , axis = axis )]
426
426
427
427
@classmethod
428
- def _pool (cls , node , input_dict , pool_func ):
428
+ def _pool (cls , node , input_dict , pool_func , guess_or_manual_pad ):
429
429
x = input_dict [node .inputs [0 ]]
430
430
x_rank = len (x .get_shape ())
431
431
@@ -435,22 +435,38 @@ def _pool(cls, node, input_dict, pool_func):
435
435
kernel_shape = node .attrs ["kernel_shape" ]
436
436
strides = node .attrs ["strides" ]
437
437
438
+ # By default, do not pad
439
+ pad = "VALID"
438
440
if "pads" in node .attrs .keys ():
439
- x = cls .get_padding_as_op (x , node .attrs ["pads" ])
441
+ if (guess_or_manual_pad == 0 ):
442
+ pad = cls .guess_tf_pad (node .attrs ["pads" ])
443
+ else :
444
+ x = cls .get_padding_as_op (x , node .attrs ["pads" ])
445
+ pad = "VALID"
440
446
441
447
if support_cuda :
442
- pooled = pool_func (x , [1 , 1 ] + kernel_shape , [1 , 1 ] + strides , "VALID" ,
448
+ pooled = pool_func (x , [1 , 1 ] + kernel_shape , [1 , 1 ] + strides , pad ,
443
449
data_format = data_format )
444
450
else :
445
451
x = tf .transpose (x , perm = [0 , 2 , 3 , 1 ])
446
- pooled = pool_func (x , [1 ] + kernel_shape + [1 ], [1 ] + strides + [1 ], "VALID" ,
452
+ pooled = pool_func (x , [1 ] + kernel_shape + [1 ], [1 ] + strides + [1 ], pad ,
447
453
data_format = data_format )
448
454
pooled = tf .transpose (pooled , perm = [0 , 3 , 1 , 2 ])
449
455
return [pooled ]
450
456
451
457
@classmethod
452
458
def handle_average_pool (cls , node , input_dict ):
453
- return cls ._pool (node , input_dict , tf .nn .avg_pool )
459
+ spatial_dim = list (input_dict [node .inputs [0 ]].get_shape ()[2 :])
460
+ kernel_shape = node .attrs .get ("kernel_shape" , [])
461
+ global_pool = True
462
+ for i in range (len (spatial_dim )):
463
+ global_pool = global_pool and (spatial_dim [i ] < kernel_shape [i ])
464
+
465
+ if global_pool :
466
+ return cls .handle_global_average_pool (node , input_dict )
467
+
468
+ # 0 = guess padding
469
+ return cls ._pool (node , input_dict , tf .nn .avg_pool , 0 )
454
470
455
471
@classmethod
456
472
def handle_batch_normalization (cls , node , input_dict ):
@@ -703,7 +719,8 @@ def handle_max(cls, node, input_dict):
703
719
704
720
@classmethod
705
721
def handle_max_pool (cls , node , input_dict ):
706
- return cls ._pool (node , input_dict , tf .nn .max_pool )
722
+ # 1 = pad manually
723
+ return cls ._pool (node , input_dict , tf .nn .max_pool , 1 )
707
724
708
725
@classmethod
709
726
def handle_min (cls , node , input_dict ):
0 commit comments