Commit a90e584
[CHERRYPICK] PIL fill len 1 seq / float fill for int images (#7951)
1 parent eab7cfb

File tree

4 files changed: +25 -50 lines

test/test_transforms_v2_refactored.py

Lines changed: 16 additions & 10 deletions
@@ -309,11 +309,12 @@ def adapt_fill(value, *, dtype):
         return value
 
     max_value = get_max_value(dtype)
+    value_type = float if dtype.is_floating_point else int
 
     if isinstance(value, (int, float)):
-        return type(value)(value * max_value)
+        return value_type(value * max_value)
     elif isinstance(value, (list, tuple)):
-        return type(value)(type(v)(v * max_value) for v in value)
+        return type(value)(value_type(v * max_value) for v in value)
     else:
         raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'.")
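To make the fix concrete, here is a minimal, self-contained condensation of the updated helper (a sketch, not the commit's code: the real get_max_value lives in the test utilities, and 255 below is a stand-in for its uint8 result). A float fill on an integer-dtype image is now scaled and truncated to an int, where the old type(value) logic kept it a float.

import torch

# Condensed sketch of the fixed adapt_fill; 255 stands in for
# get_max_value(torch.uint8) from the test utilities.
def adapt_fill(value, *, dtype):
    max_value = 255 if dtype == torch.uint8 else 1.0
    value_type = float if dtype.is_floating_point else int
    if isinstance(value, (int, float)):
        return value_type(value * max_value)
    if isinstance(value, (list, tuple)):
        return type(value)(value_type(v * max_value) for v in value)
    raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'.")

print(adapt_fill(0.5, dtype=torch.uint8))         # 127 -- the old code returned 127.5
print(adapt_fill([0.5, 1.0], dtype=torch.uint8))  # [127, 255]
print(adapt_fill(0.5, dtype=torch.float32))       # 0.5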

@@ -414,6 +415,10 @@ def affine_bounding_boxes(bounding_boxes):
     )
 
 
+# turns all warnings into errors for this module
+pytestmark = pytest.mark.filterwarnings("error")
+
+
 class TestResize:
     INPUT_SIZE = (17, 11)
     OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
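As a quick illustration of what the module-level mark does (a sketch, not from the commit): any warning emitted inside a test in this file now surfaces as an exception, which is what lets the suite catch problems like the implicit float -> int conversion addressed in _geometry.py below.

import warnings
import pytest

pytestmark = pytest.mark.filterwarnings("error")

def test_warning_becomes_error():
    # With the "error" filter, the warning is raised instead of printed,
    # so it can be asserted on (or it fails the test if unexpected).
    with pytest.raises(UserWarning, match="implicit conversion"):
        warnings.warn("implicit conversion", UserWarning)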
@@ -2575,18 +2580,19 @@ def test_functional_image_correctness(self, kwargs):
     def test_transform(self, param, value, make_input):
         input = make_input(self.INPUT_SIZE)
 
-        kwargs = {param: value}
         if param == "fill":
-            # 1. size is required
-            # 2. the fill parameter only has an effect if we need padding
-            kwargs["size"] = [s + 4 for s in self.INPUT_SIZE]
-
-            if isinstance(input, PIL.Image.Image) and isinstance(value, (tuple, list)) and len(value) == 1:
-                pytest.xfail("F._pad_image_pil does not support sequences of length 1 for fill.")
-
             if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)):
                 pytest.skip("F.pad_mask doesn't support non-scalar fill.")
 
+            kwargs = dict(
+                # 1. size is required
+                # 2. the fill parameter only has an effect if we need padding
+                size=[s + 4 for s in self.INPUT_SIZE],
+                fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8),
+            )
+        else:
+            kwargs = {param: value}
+
         check_transform(
             transforms.RandomCrop(**kwargs, pad_if_needed=True),
             input,
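For reference, a hedged sketch of the scenario the rewritten test exercises (shapes and values chosen to mirror the test constants; not part of the commit): RandomCrop with pad_if_needed=True and a target larger than the input must pad, so the fill value actually takes effect and must match the image dtype.

import torch
import torchvision.transforms.v2 as transforms

INPUT_SIZE = (17, 11)
img = torch.randint(0, 256, (3, *INPUT_SIZE), dtype=torch.uint8)

# fill=127 is what adapt_fill(0.5, dtype=torch.uint8) produces in the sketch above.
t = transforms.RandomCrop(size=[s + 4 for s in INPUT_SIZE], fill=127, pad_if_needed=True)
out = t(img)
print(out.shape, out.dtype)  # torch.Size([3, 21, 15]) torch.uint8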

test/transforms_v2_dispatcher_infos.py

Lines changed: 0 additions & 37 deletions
@@ -1,5 +1,3 @@
-import collections.abc
-
 import pytest
 import torchvision.transforms.v2.functional as F
 from torchvision import tv_tensors
@@ -112,32 +110,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
 multi_crop_skips.append(skip_dispatch_tv_tensor)
 
 
-def xfails_pil(reason, *, condition=None):
-    return [
-        TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
-        for test_name in ["test_dispatch_pil", "test_pil_output_type"]
-    ]
-
-
-def fill_sequence_needs_broadcast(args_kwargs):
-    (image_loader, *_), kwargs = args_kwargs
-    try:
-        fill = kwargs["fill"]
-    except KeyError:
-        return False
-
-    if not isinstance(fill, collections.abc.Sequence) or len(fill) > 1:
-        return False
-
-    return image_loader.num_channels > 1
-
-
-xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
-    "PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
-    condition=fill_sequence_needs_broadcast,
-)
-
-
 DISPATCHER_INFOS = [
     DispatcherInfo(
         F.resized_crop,
@@ -159,14 +131,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
         test_marks=[
-            *xfails_pil(
-                reason=(
-                    "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
-                    "`padding_mode='constant'`, if the number of color channels is larger."
-                ),
-                condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
-                and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
-            ),
             xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
             xfail_jit_python_scalar_arg("padding"),
         ],
@@ -181,7 +145,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
         test_marks=[
-            *xfails_pil_if_fill_sequence_needs_broadcast,
             xfail_jit_python_scalar_arg("fill"),
         ],
     ),

torchvision/transforms/_functional_pil.py

Lines changed: 4 additions & 2 deletions
@@ -264,11 +264,13 @@ def _parse_fill(
     if isinstance(fill, (int, float)) and num_channels > 1:
         fill = tuple([fill] * num_channels)
     if isinstance(fill, (list, tuple)):
-        if len(fill) != num_channels:
+        if len(fill) == 1:
+            fill = fill * num_channels
+        elif len(fill) != num_channels:
             msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
             raise ValueError(msg.format(len(fill), num_channels))
 
-        fill = tuple(fill)
+        fill = tuple(fill)  # type: ignore[arg-type]
 
     if img.mode != "F":
         if isinstance(fill, (list, tuple)):
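The effect of the broadcast, sketched against the public v2 functional API (assuming the PIL path routes fill through _parse_fill as shown above): a length-1 fill sequence on a multi-channel PIL image used to raise the "1 != 3" ValueError and is now expanded across channels.

import PIL.Image
import torchvision.transforms.v2.functional as F

img = PIL.Image.new("RGB", (4, 4))
# Before this commit: ValueError ("... does not match ... (1 != 3)").
# After: [10] is broadcast to (10, 10, 10).
padded = F.pad(img, padding=[2], fill=[10])
print(padded.size)              # (8, 8)
print(padded.getpixel((0, 0)))  # (10, 10, 10)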

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 5 additions & 1 deletion
@@ -1235,7 +1235,11 @@ def _pad_with_vector_fill(
 
     output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
     left, right, top, bottom = torch_padding
-    fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)
+
+    # We create the tensor with the autodetected dtype first and convert it to the right one afterwards to avoid an
+    # implicit float -> int conversion. That happens, for example, for the valid input of a uint8 image with a
+    # floating point fill value.
+    fill = torch.tensor(fill, device=image.device).to(dtype=image.dtype).reshape(-1, 1, 1)
 
     if top > 0:
         output[..., :top, :] = fill
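A small sketch of the failure mode the comment describes (behavior hedged; the exact warning text varies by PyTorch version): passing a float fill for a uint8 image through torch.tensor(..., dtype=...) relies on PyTorch's implicit float -> int conversion, which the test suite's new warnings-as-errors mark would turn into a failure, while the two-step construction casts explicitly.

import torch

fill = [0.5, 0.5, 0.5]  # valid float fill for a uint8 image

# Old: torch.tensor(fill, dtype=torch.uint8, device="cpu")
#      -> implicit float -> int conversion (warns on recent PyTorch).
# New: build in the autodetected float dtype, then cast explicitly.
t = torch.tensor(fill, device="cpu").to(dtype=torch.uint8).reshape(-1, 1, 1)
print(t.flatten())  # tensor([0, 0, 0], dtype=torch.uint8)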
