Skip to content

Commit e44bba1

Browse files
authored
Fix: don't call round() on float images for ResizeV2 (#7669)
1 parent 906c2e9 commit e44bba1

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

test/test_transforms_v2_functional.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,3 +1395,13 @@ def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwarg
13951395
assert expected_stride == output_stride, error_msg_fn("")
13961396
else:
13971397
assert False, error_msg_fn("")
1398+
1399+
1400+
def test_resize_float16_no_rounding():
1401+
# Make sure Resize() doesn't round float16 images
1402+
# Non-regression test for https://github.com/pytorch/vision/issues/7667
1403+
1404+
img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
1405+
out = F.resize(img, size=(10, 10))
1406+
assert out.dtype == torch.float16
1407+
assert (out.round() - out).sum() > 0

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,9 @@ def resize_image_tensor(
228228
if need_cast:
229229
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
230230
image = image.clamp_(min=0, max=255)
231-
image = image.round_().to(dtype=dtype)
231+
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
232+
image = image.round_()
233+
image = image.to(dtype=dtype)
232234

233235
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
234236

0 commit comments

Comments
 (0)