Skip to content

Commit 095b51b

Browse files
authored
Fix Resize issue (onnx#702)
1 parent 44c0927 commit 095b51b

File tree

2 files changed

+139
-193
lines changed

2 files changed

+139
-193
lines changed

onnx_tf/handlers/backend/resize.py

Lines changed: 79 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -120,46 +120,35 @@ def args_check(cls, node, **kwargs):
120120

121121
@classmethod
122122
def version_10(cls, node, **kwargs):
123+
# x, roi and scales are all in NCHW format
123124
x = kwargs["tensor_dict"][node.inputs[0]]
124125
x_shape = tf_shape(x)
125126
scales = kwargs["tensor_dict"][node.inputs[1]]
126127

127-
n_in_scales_is_one = tf.equal(scales[0], 1)
128-
c_in_scales_is_one = tf.logical_or(tf.equal(scales[1], 1),
129-
tf.equal(scales[3], 1))
130-
assert_n_c_in_scales_are_ones = tf.Assert(
131-
tf.logical_and(n_in_scales_is_one, c_in_scales_is_one), [scales])
128+
h_w_scale = scales[2:]
129+
h_w_shape = x_shape[2:]
130+
new_h_w_shape = tf.cast(h_w_scale * tf.cast(h_w_shape, scales.dtype),
131+
tf.int32)
132132

133-
with tf.control_dependencies([assert_n_c_in_scales_are_ones]):
134-
x_in_NCHW_format = tf.equal(scales[1], 1)
135-
h_w_scale = tf.where(x_in_NCHW_format, scales[2:], scales[1:3])
136-
h_w_shape = tf.where(x_in_NCHW_format, x_shape[2:], x_shape[1:3])
137-
new_h_w_shape = tf.cast(h_w_scale * tf.cast(h_w_shape, scales.dtype),
138-
tf.int32)
139-
140-
mode = node.attrs.get("mode", "nearest")
141-
if mode.lower() == "linear":
142-
mode = tf.image.ResizeMethod.BILINEAR
143-
else:
144-
mode = tf.image.ResizeMethod.NEAREST_NEIGHBOR
145-
146-
def process_NCHW_format(x):
147-
x_t = tf.transpose(x, perm=[0, 2, 3, 1])
148-
y = tf.image.resize(x_t, size=new_h_w_shape, method=mode)
149-
y_t = tf.transpose(y, perm=[0, 3, 1, 2])
150-
return y_t
151-
152-
def process_NHWC_format(x):
153-
y = tf.image.resize(x, size=new_h_w_shape, method=mode)
154-
return y
133+
mode = node.attrs.get("mode", "nearest")
134+
if mode.lower() == "linear":
135+
mode = tf.image.ResizeMethod.BILINEAR
136+
else:
137+
mode = tf.image.ResizeMethod.NEAREST_NEIGHBOR
155138

156-
output = tf.cond(x_in_NCHW_format, lambda: process_NCHW_format(x),
157-
lambda: process_NHWC_format(x))
139+
# The input image is in NCHW format. But tf.image.resize only
140+
# support channel last data format. Therefore need to transpose
141+
# to NHWC format first then process resize and then transpose
142+
# back to NCHW format.
143+
x_t = tf.transpose(x, perm=[0, 2, 3, 1])
144+
y = tf.image.resize(x_t, size=new_h_w_shape, method=mode)
145+
output = tf.transpose(y, perm=[0, 3, 1, 2])
158146

159-
return [output]
147+
return [output]
160148

161149
@classmethod
162150
def version_11(cls, node, **kwargs):
151+
# x, roi, scales and sizes are all in NCHW format
163152
tensor_dict = kwargs["tensor_dict"]
164153
x = tensor_dict[node.inputs[0]]
165154
x_shape = tf_shape(x)
@@ -172,99 +161,63 @@ def version_11(cls, node, **kwargs):
172161
extrapolation_value = node.attrs.get("extrapolation_value", 0.0)
173162
mode = node.attrs.get("mode", "nearest")
174163

175-
param = scales if len(node.inputs) == 3 else sizes
176-
n_in_param_is_one = tf.equal(param[0], 1)
177-
c_in_param_is_one = tf.logical_or(tf.equal(param[1], 1),
178-
tf.equal(param[3], 1))
179-
assert_n_c_in_param_are_ones = tf.Assert(
180-
tf.logical_and(n_in_param_is_one, c_in_param_is_one), [param])
181-
182-
with tf.control_dependencies([assert_n_c_in_param_are_ones]):
183-
if mode.lower() == "linear":
184-
mode = tf.image.ResizeMethod.BILINEAR
185-
tf_resize = tf.compat.v1.image.resize_bilinear
186-
elif mode.lower() == "cubic":
187-
mode = tf.image.ResizeMethod.BICUBIC
188-
tf_resize = tf.compat.v1.image.resize_bicubic
189-
else:
190-
mode = tf.image.ResizeMethod.NEAREST_NEIGHBOR
191-
tf_resize = tf.compat.v1.image.resize_nearest_neighbor
192-
193-
x_in_NCHW_format = tf.equal(param[1], 1)
194-
195-
if len(node.inputs) == 3: # only scales is defined
196-
h_w_scale = tf.where(x_in_NCHW_format, scales[2:], scales[1:3])
197-
h_w_shape = tf.where(x_in_NCHW_format, x_shape[2:], x_shape[1:3])
198-
new_size = tf.cast(h_w_scale * tf.cast(h_w_shape, scales.dtype),
199-
tf.int32)
200-
else: # sizes is defined
201-
# The number of elements of 'sizes' should be the same as the rank of input 'X'
202-
sizes.set_shape(x_shape.shape)
203-
new_size = tf.cast(tf.where(x_in_NCHW_format, sizes[2:], sizes[1:3]),
204-
tf.int32)
205-
# Tensorflow require the shape of "size" in the "tf.image.resize" must be known at
206-
# graph creation time. However in the dynamic shape situation, the shape of "new_size"
207-
# will be "None", the actual shape can only be determine at runtime. But we know
208-
# "new_size" should always contain [h, w], therefore the shape must be 2.
209-
new_size.set_shape([2])
210-
211-
def get_NCHW_boxes():
212-
indices = []
213-
x_rank = len(x.get_shape())
214-
for i in range(2, x_rank):
215-
indices.insert(i - 2, i)
216-
indices.insert(i, i + x_rank)
217-
return tf.expand_dims(tf.gather(roi, indices, axis=0), 0)
218-
219-
def get_NHWC_boxes():
220-
indices = []
221-
x_rank = len(x.get_shape())
222-
for i in range(1, x_rank - 1):
223-
indices.insert(i - 1, i)
224-
indices.insert(i + 1, i + x_rank)
225-
return tf.expand_dims(tf.gather(roi, indices, axis=0), 0)
226-
227-
box_indices = tf.cast(tf.range(0, x_shape[0]), dtype=tf.int32)
228-
229-
def process_NCHW_format():
230-
x_t = tf.transpose(x, perm=[0, 2, 3, 1])
231-
if coordinate_transformation_mode == "tf_crop_and_resize":
232-
boxes = get_NCHW_boxes()
233-
y = tf.image.crop_and_resize(x_t, boxes, box_indices, new_size, mode,
234-
extrapolation_value)
235-
elif coordinate_transformation_mode == "align_corners":
236-
y = tf_resize(x_t,
237-
size=new_size,
238-
align_corners=True,
239-
half_pixel_centers=False)
240-
elif coordinate_transformation_mode == "asymmetric":
241-
y = tf_resize(x_t,
242-
size=new_size,
243-
align_corners=False,
244-
half_pixel_centers=False)
245-
else: # half_pixel or tf_half_pixel_for_nn
246-
y = tf.image.resize(x_t, size=new_size, method=mode)
247-
return tf.transpose(y, perm=[0, 3, 1, 2])
248-
249-
def process_NHWC_format():
250-
if coordinate_transformation_mode == "tf_crop_and_resize":
251-
boxes = get_NHWC_boxes()
252-
return tf.image.crop_and_resize(x, boxes, box_indices, new_size, mode,
253-
extrapolation_value)
254-
elif coordinate_transformation_mode == "align_corners":
255-
return tf_resize(x,
256-
size=new_size,
257-
align_corners=True,
258-
half_pixel_centers=False)
259-
elif coordinate_transformation_mode == "asymmetric":
260-
return tf_resize(x,
261-
size=new_size,
262-
align_corners=False,
263-
half_pixel_centers=False)
264-
else: # half_pixel or tf_half_pixel_for_nn
265-
return tf.image.resize(x, size=new_size, method=mode)
266-
267-
output = tf.cond(x_in_NCHW_format, process_NCHW_format,
268-
process_NHWC_format)
269-
270-
return [output]
164+
if mode.lower() == "linear":
165+
mode = tf.image.ResizeMethod.BILINEAR
166+
tf_resize = tf.compat.v1.image.resize_bilinear
167+
elif mode.lower() == "cubic":
168+
mode = tf.image.ResizeMethod.BICUBIC
169+
tf_resize = tf.compat.v1.image.resize_bicubic
170+
else:
171+
mode = tf.image.ResizeMethod.NEAREST_NEIGHBOR
172+
tf_resize = tf.compat.v1.image.resize_nearest_neighbor
173+
174+
if len(node.inputs) == 3: # only scales is defined
175+
h_w_scale = scales[2:]
176+
h_w_shape = x_shape[2:]
177+
new_size = tf.cast(h_w_scale * tf.cast(h_w_shape, scales.dtype),
178+
tf.int32)
179+
else: # sizes is defined
180+
# The number of elements of 'sizes' should be the same as the rank of input 'X'
181+
sizes.set_shape(x_shape.shape)
182+
new_size = tf.cast(sizes[2:], tf.int32)
183+
# Tensorflow require the shape of "size" in the "tf.image.resize" must be known at
184+
# graph creation time. However in the dynamic shape situation, the shape of "new_size"
185+
# will be "None", the actual shape can only be determine at runtime. But we know
186+
# "new_size" should always contain [h, w], therefore the shape must be 2.
187+
new_size.set_shape([2])
188+
189+
# get boxes for crop
190+
indices = []
191+
x_rank = len(x.get_shape())
192+
for i in range(2, x_rank):
193+
indices.insert(i - 2, i)
194+
indices.insert(i, i + x_rank)
195+
boxes = tf.expand_dims(tf.gather(roi, indices, axis=0), 0)
196+
197+
# get box_indices for crop
198+
box_indices = tf.cast(tf.range(0, x_shape[0]), dtype=tf.int32)
199+
200+
# The input image is in NCHW format. But tf.image.crop_and_resize,
201+
# tf.image.resize and tf.compat.v1.image.resize_xx only support
202+
# channel last data format. Therefore need to transpose to NHWC
203+
# formar first then process resize and then transpose back to
204+
# NCHW format.
205+
x_t = tf.transpose(x, perm=[0, 2, 3, 1])
206+
if coordinate_transformation_mode == "tf_crop_and_resize":
207+
y = tf.image.crop_and_resize(x_t, boxes, box_indices, new_size, mode,
208+
extrapolation_value)
209+
elif coordinate_transformation_mode == "align_corners":
210+
y = tf_resize(x_t,
211+
size=new_size,
212+
align_corners=True,
213+
half_pixel_centers=False)
214+
elif coordinate_transformation_mode == "asymmetric":
215+
y = tf_resize(x_t,
216+
size=new_size,
217+
align_corners=False,
218+
half_pixel_centers=False)
219+
else: # half_pixel or tf_half_pixel_for_nn
220+
y = tf.image.resize(x_t, size=new_size, method=mode)
221+
output = tf.transpose(y, perm=[0, 3, 1, 2])
222+
223+
return [output]

0 commit comments

Comments
 (0)