Commit 09f6b04

refactoring

1 parent 3cc1964 commit 09f6b04

File tree: 2 files changed (+39, -36 lines)

torchvision/prototype/transforms/_geometry.py

Lines changed: 35 additions & 32 deletions
@@ -1,7 +1,7 @@
 import collections.abc
 import math
 import warnings
-from typing import Any, Dict, List, Union, Sequence, Tuple, cast, Literal
+from typing import Any, Dict, List, Union, Sequence, Tuple, cast, Literal, Optional
 
 import PIL.Image
 import torch
@@ -262,7 +262,7 @@ class RandomCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        padding: Sequence[int],
+        padding: Optional[Sequence[int]] = None,
         pad_if_needed: bool = False,
         fill: Union[int, str, Sequence[int]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
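
With the new default, callers can omit `padding` entirely, where the previous signature required it. A hypothetical usage sketch against the prototype namespace this file lives in (the exact import path and availability of the prototype area are assumptions of this note, not shown in the diff):

    from torchvision.prototype import transforms

    crop = transforms.RandomCrop(32)                      # padding defaults to None
    crop_padded = transforms.RandomCrop(32, padding=[4])  # explicit padding still works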
@@ -275,15 +275,15 @@ def __init__(
         self.fill = fill
         self.padding_mode = padding_mode
 
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_crop_parameters(self, image: Any) -> Dict[str, Any]:
         """Get parameters for ``crop`` for a random crop.
         Args:
             sample (PIL Image, Tensor or features.Image): Image to be cropped.
         Returns:
             dict: Dict containing 'top', 'left', 'height', and 'width'
         """
 
-        _, h, w = get_image_dimensions(sample)
+        _, h, w = get_image_dimensions(image)
 
         th, tw = self.size
 
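The renamed helper still samples the crop origin uniformly over the valid offsets. In isolation, that sampling looks like the standalone sketch below (`sample_crop_params` is illustrative, not the method itself; the real method reads `th, tw` from `self.size`):

    import torch

    def sample_crop_params(h: int, w: int, th: int, tw: int) -> dict:
        # the +1 makes the largest valid offset reachable and covers h == th / w == tw,
        # since torch.randint's upper bound is exclusive
        top = int(torch.randint(0, h - th + 1, size=(1,)).item())
        left = int(torch.randint(0, w - tw + 1, size=(1,)).item())
        return dict(top=top, left=left, height=th, width=tw)

    print(sample_crop_params(32, 32, 24, 24))  # e.g. {'top': 3, 'left': 7, 'height': 24, 'width': 24}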
@@ -298,51 +298,54 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(top=i, left=j, height=th, width=tw)
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, features.Image):
-            output = F.crop_image_tensor(input, **params)
-            return features.Image.new_like(input, output)
-        elif isinstance(input, PIL.Image.Image):
-            return F.crop_image_pil(input, **params)
-        elif is_simple_tensor(input):
-            return F.crop_image_tensor(input, **params)
-        else:
-            return input
 
-    def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        if has_any(sample, features.BoundingBox, features.SegmentationMask):
-            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
-
-        if isinstance(sample, features.Image):
+        if isinstance(input, features.Image):
             output = F.random_pad_image_tensor(
-                sample,
+                input,
                 output_size=self.size,
-                image_size=get_image_dimensions(sample),
+                image_size=get_image_dimensions(input),
                 padding=cast(List[int], tuple(self.padding)),
                 pad_if_needed=self.pad_if_needed,
                 fill=self.fill,
                 padding_mode=self.padding_mode,
             )
-            sample = features.Image.new_like(sample, output)
-        elif isinstance(sample, PIL.Image.Image):
-            sample = F.random_pad_image_pil(
-                sample,
+            input = features.Image.new_like(input, output)
+        elif isinstance(input, PIL.Image.Image):
+            input = F.random_pad_image_pil(
+                input,
                 output_size=self.size,
-                image_size=get_image_dimensions(sample),
-                padding=cast(List[int], tuple(self.padding)),
+                image_size=get_image_dimensions(input),
+                padding=self.padding,
                 pad_if_needed=self.pad_if_needed,
                 fill=self.fill,
                 padding_mode=self.padding_mode,
             )
-        elif is_simple_tensor(sample):
-            sample = F.random_pad_image_tensor(
-                sample,
+        elif is_simple_tensor(input):
+            input = F.random_pad_image_tensor(
+                input,
                 output_size=self.size,
-                image_size=get_image_dimensions(sample),
-                padding=cast(List[int], tuple(self.padding)),
+                image_size=get_image_dimensions(input),
+                padding=self.padding,
                 pad_if_needed=self.pad_if_needed,
                 fill=self.fill,  # TODO: should be converted to number
                 padding_mode=self.padding_mode,
             )
 
+        params.update(self._get_crop_parameters(input))
+
+        if isinstance(input, features.Image):
+            output = F.crop_image_tensor(input, **params)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, PIL.Image.Image):
+            return F.crop_image_pil(input, **params)
+        elif is_simple_tensor(input):
+            return F.crop_image_tensor(input, **params)
+        else:
+            return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+
         return super().forward(sample)
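
Taken together, this hunk inverts the old structure: `forward` no longer pads; `_transform` pads first, then computes the crop location from the already-padded input via `self._get_crop_parameters(input)`, and finally crops. A minimal sketch of that pad-then-crop ordering, written against the stable `torchvision.transforms.functional` API for illustration (the `pad_then_crop` helper is an assumption of this note, not code from the commit):

    import torch
    import torchvision.transforms.functional as TF

    def pad_then_crop(img, size, padding=None, pad_if_needed=False):
        th, tw = size
        if padding is not None:
            img = TF.pad(img, list(padding))       # step 1: apply user padding
        _, h, w = img.shape
        if pad_if_needed and h < th:               # step 2: grow until the crop fits
            img = TF.pad(img, [0, th - h])
            _, h, w = img.shape
        if pad_if_needed and w < tw:
            img = TF.pad(img, [tw - w, 0])
            _, h, w = img.shape
        top = int(torch.randint(0, h - th + 1, (1,)).item())   # step 3: offsets from
        left = int(torch.randint(0, w - tw + 1, (1,)).item())  # the PADDED size
        return TF.crop(img, top, left, th, tw)     # step 4: crop

    out = pad_then_crop(torch.rand(3, 24, 24), size=(32, 32), padding=[4, 4])
    print(out.shape)  # torch.Size([3, 32, 32])

The ordering is the point: sampling the offsets before padding could make `h - th + 1` non-positive in exactly the cases `pad_if_needed` is meant to guarantee the crop fits.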

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 4 additions & 4 deletions
@@ -399,7 +399,7 @@ def random_crop_image_tensor(
     height: int,
     width: int,
     size: List[int],
-    padding: List[int] = None,
+    padding: Optional[List[int]] = None,
     pad_if_needed: bool = False,
     fill: int = 0,
     padding_mode: str = "constant",
@@ -430,7 +430,7 @@ def random_crop_image_pil(
     height: int,
     width: int,
     size: List[int],
-    padding: List[int] = None,
+    padding: Optional[List[int]] = None,
     pad_if_needed: bool = False,
     fill: int = 0,
     padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
@@ -457,7 +457,7 @@ def random_pad_image_tensor(
     img: torch.Tensor,
     output_size: List[int],
     image_size: Tuple[int, int, int],
-    padding: List[int],
+    padding: Optional[Sequence[int]] = None,
     pad_if_needed: bool = False,
     fill: int = 0,
     padding_mode: str = "constant",
@@ -481,7 +481,7 @@ def random_pad_image_pil(
     img: PIL.Image.Image,
     output_size: List[int],
     image_size: Tuple[int, int, int],
-    padding: List[int],
+    padding: Optional[Sequence[int]] = None,
     pad_if_needed: bool = False,
     fill: Union[int, str, Sequence[int]] = 0,
     padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
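
The first two hunks fix an implicit-`Optional` annotation: `padding: List[int] = None` declares a `List[int]` but defaults to `None`, which strict type checkers (for example mypy with `no_implicit_optional`) reject. The last two additionally turn `padding` from a required argument into an optional one, matching the new `RandomCrop.__init__` signature. A minimal illustration with a hypothetical stub (`random_crop_stub` is not from the commit):

    from typing import List, Optional

    # Before: `def f(padding: List[int] = None)` -- an implicit Optional;
    # strict checkers flag the None default as incompatible with List[int].
    # After: the None default is spelled out in the annotation.
    def random_crop_stub(size: List[int], padding: Optional[List[int]] = None) -> None:
        if padding is not None:                 # callers may now omit padding
            assert len(padding) in (1, 2, 4)    # torchvision-style padding arity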
