Skip to content

Commit 55b4efa

Browse files
committed
pad in forward
1 parent e89a31b commit 55b4efa

File tree

3 files changed

+99
-14
lines changed

3 files changed

+99
-14
lines changed

torchvision/prototype/transforms/_geometry.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import collections.abc
22
import math
33
import warnings
4-
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
4+
from typing import Any, Dict, List, Union, Sequence, Tuple, cast, Literal
55

66
import PIL.Image
77
import torch
@@ -259,7 +259,14 @@ def apply_recursively(obj: Any) -> Any:
259259

260260

261261
class RandomCrop(Transform):
262-
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
262+
def __init__(
263+
self,
264+
size=Union[int, Sequence[int]],
265+
padding: Sequence[int] = None,
266+
pad_if_needed: bool = False,
267+
fill: Union[int, str, Sequence[int]] = 0,
268+
padding_mode: Union[str, Literal["constant", "edge", "reflect", "symmetric"]] = "constant",
269+
):
263270
super().__init__()
264271
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
265272

@@ -269,12 +276,9 @@ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode
269276
self.padding_mode = padding_mode
270277

271278
def _get_params(self, sample: Any) -> Dict[str, Any]:
272-
273279
"""Get parameters for ``crop`` for a random crop.
274-
275280
Args:
276281
sample (PIL Image, Tensor or features.Image): Image to be cropped.
277-
278282
Returns:
279283
dict: Dict containing 'top', 'left', 'height', and 'width'
280284
"""
@@ -294,19 +298,51 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
294298
return dict(top=i, left=j, height=th, width=tw)
295299

296300
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
    """Crop ``input`` at the location sampled in ``params``.

    Image-like inputs (feature images, PIL images, plain tensors) are
    cropped; any other input is returned untouched.
    """
    if isinstance(input, features.Image):
        cropped = F.crop_image_tensor(input, **params)
        return features.Image.new_like(input, cropped)
    if isinstance(input, PIL.Image.Image):
        return F.crop_image_pil(input, **params)
    if is_simple_tensor(input):
        return F.crop_image_tensor(input, **params)
    # Non-image payloads (labels, metadata, ...) pass through unchanged.
    return input
307310

308311
def forward(self, *inputs: Any) -> Any:
    """Pad the sample (per ``padding`` / ``pad_if_needed``) before cropping.

    The padding step runs here, on the whole sample, so that the crop
    parameters sampled later by ``_get_params`` see the padded size.

    Raises:
        TypeError: if the sample contains bounding boxes or segmentation
            masks, which this transform does not support.
    """
    sample = inputs if len(inputs) > 1 else inputs[0]
    if has_any(sample, features.BoundingBox, features.SegmentationMask):
        raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")

    if isinstance(sample, (features.Image, PIL.Image.Image)) or is_simple_tensor(sample):
        # All three image kinds take the exact same padding arguments;
        # build them once instead of repeating the call three times.
        pad_kwargs = dict(
            output_size=self.size,
            image_size=get_image_dimensions(sample),
            padding=self.padding,
            pad_if_needed=self.pad_if_needed,
            fill=self.fill,
            padding_mode=self.padding_mode,
        )
        if isinstance(sample, features.Image):
            output = F.random_pad_image_tensor(sample, **pad_kwargs)
            sample = features.Image.new_like(sample, output)
        elif isinstance(sample, PIL.Image.Image):
            sample = F.random_pad_image_pil(sample, **pad_kwargs)
        else:
            sample = F.random_pad_image_tensor(sample, **pad_kwargs)
    return super().forward(sample)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@
4848
center_crop_image_pil,
4949
resized_crop_image_tensor,
5050
resized_crop_image_pil,
51-
random_crop_image_tensor,
52-
random_crop_image_pil,
51+
random_pad_image_tensor,
52+
random_pad_image_pil,
5353
affine_image_tensor,
5454
affine_image_pil,
5555
rotate_image_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,3 +451,52 @@ def random_crop_image_pil(
451451
img = pad_image_pil(img, padding, fill, padding_mode)
452452

453453
return crop_image_pil(img, top, left, height, width)
454+
455+
456+
def random_pad_image_tensor(
    img: torch.Tensor,
    output_size: List[int],
    image_size: Tuple[int, int, int],
    padding: Optional[List[int]] = None,
    pad_if_needed: bool = False,
    fill: int = 0,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad ``img`` so that a subsequent ``output_size`` crop always fits.

    Args:
        img: image tensor of shape ``(..., H, W)``.
        output_size: target crop size as ``[height, width]``.
        image_size: ``(channels, height, width)`` of ``img`` *before* padding.
        padding: optional explicit padding applied first, in the torchvision
            ``pad`` convention (int, or a sequence of 1, 2, or 4 values).
        pad_if_needed: when True, additionally pad any dimension still
            smaller than ``output_size`` after the explicit padding.
        fill: fill value for ``"constant"`` padding.
        padding_mode: ``"constant"``, ``"edge"``, ``"reflect"`` or ``"symmetric"``.

    Returns:
        The (possibly) padded image tensor.
    """
    _, height, width = image_size

    if padding is not None:
        img = pad_image_tensor(img, padding, fill, padding_mode)
        # BUGFIX: the explicit padding changes the image size, so the
        # pad_if_needed checks below must compare against the *padded*
        # dimensions (stable torchvision recomputes the size after padding).
        pad = [padding] * 4 if isinstance(padding, int) else list(padding)
        if len(pad) == 1:
            pad = pad * 4
        elif len(pad) == 2:
            pad = [pad[0], pad[1], pad[0], pad[1]]
        left, top, right, bottom = pad
        width += left + right
        height += top + bottom
    # pad the width if needed (two-value padding pads left and right)
    if pad_if_needed and width < output_size[1]:
        padding = [output_size[1] - width, 0]
        img = pad_image_tensor(img, padding, fill, padding_mode)
    # pad the height if needed (two-value padding pads top and bottom)
    if pad_if_needed and height < output_size[0]:
        padding = [0, output_size[0] - height]
        img = pad_image_tensor(img, padding, fill, padding_mode)
    return img
478+
479+
480+
def random_pad_image_pil(
    img: "PIL.Image.Image",
    output_size: List[int],
    image_size: Tuple[int, int, int],
    padding: Optional[List[int]] = None,
    pad_if_needed: bool = False,
    fill: int = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> "PIL.Image.Image":
    """Pad ``img`` so that a subsequent ``output_size`` crop always fits.

    Args:
        img: PIL image to pad.
        output_size: target crop size as ``[height, width]``.
        image_size: ``(channels, height, width)`` of ``img`` *before* padding.
        padding: optional explicit padding applied first, in the torchvision
            ``pad`` convention (int, or a sequence of 1, 2, or 4 values).
        pad_if_needed: when True, additionally pad any dimension still
            smaller than ``output_size`` after the explicit padding.
        fill: fill value for ``"constant"`` padding.
        padding_mode: ``"constant"``, ``"edge"``, ``"reflect"`` or ``"symmetric"``.

    Returns:
        The (possibly) padded PIL image.
    """
    _, height, width = image_size

    if padding is not None:
        img = pad_image_pil(img, padding, fill, padding_mode)
        # BUGFIX: the explicit padding changes the image size, so the
        # pad_if_needed checks below must compare against the *padded*
        # dimensions (stable torchvision recomputes the size after padding).
        pad = [padding] * 4 if isinstance(padding, int) else list(padding)
        if len(pad) == 1:
            pad = pad * 4
        elif len(pad) == 2:
            pad = [pad[0], pad[1], pad[0], pad[1]]
        left, top, right, bottom = pad
        width += left + right
        height += top + bottom
    # pad the width if needed (two-value padding pads left and right)
    if pad_if_needed and width < output_size[1]:
        padding = [output_size[1] - width, 0]
        img = pad_image_pil(img, padding, fill, padding_mode)
    # pad the height if needed (two-value padding pads top and bottom)
    if pad_if_needed and height < output_size[0]:
        padding = [0, output_size[0] - height]
        img = pad_image_pil(img, padding, fill, padding_mode)
    return img

0 commit comments

Comments
 (0)