1
1
import collections .abc
2
2
import math
3
3
import warnings
4
- from typing import Any , Dict , List , Union , Sequence , Tuple , cast , Literal
4
+ from typing import Any , Dict , List , Union , Sequence , Tuple , cast , Literal , Optional
5
5
6
6
import PIL .Image
7
7
import torch
@@ -262,7 +262,7 @@ class RandomCrop(Transform):
262
262
def __init__ (
263
263
self ,
264
264
size : Union [int , Sequence [int ]],
265
- padding : Sequence [int ],
265
+ padding : Optional [ Sequence [int ]] = None ,
266
266
pad_if_needed : bool = False ,
267
267
fill : Union [int , str , Sequence [int ]] = 0 ,
268
268
padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ] = "constant" ,
@@ -275,15 +275,15 @@ def __init__(
275
275
self .fill = fill
276
276
self .padding_mode = padding_mode
277
277
278
- def _get_params (self , sample : Any ) -> Dict [str , Any ]:
278
+ def _get_crop_parameters (self , image : Any ) -> Dict [str , Any ]:
279
279
"""Get parameters for ``crop`` for a random crop.
280
280
Args:
281
281
sample (PIL Image, Tensor or features.Image): Image to be cropped.
282
282
Returns:
283
283
dict: Dict containing 'top', 'left', 'height', and 'width'
284
284
"""
285
285
286
- _ , h , w = get_image_dimensions (sample )
286
+ _ , h , w = get_image_dimensions (image )
287
287
288
288
th , tw = self .size
289
289
@@ -298,51 +298,54 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
298
298
return dict (top = i , left = j , height = th , width = tw )
299
299
300
300
def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
301
- if isinstance (input , features .Image ):
302
- output = F .crop_image_tensor (input , ** params )
303
- return features .Image .new_like (input , output )
304
- elif isinstance (input , PIL .Image .Image ):
305
- return F .crop_image_pil (input , ** params )
306
- elif is_simple_tensor (input ):
307
- return F .crop_image_tensor (input , ** params )
308
- else :
309
- return input
310
301
311
- def forward (self , * inputs : Any ) -> Any :
312
- sample = inputs if len (inputs ) > 1 else inputs [0 ]
313
- if has_any (sample , features .BoundingBox , features .SegmentationMask ):
314
- raise TypeError (f"BoundingBox'es and SegmentationMask's are not supported by { type (self ).__name__ } ()" )
315
-
316
- if isinstance (sample , features .Image ):
302
+ if isinstance (input , features .Image ):
317
303
output = F .random_pad_image_tensor (
318
- sample ,
304
+ input ,
319
305
output_size = self .size ,
320
- image_size = get_image_dimensions (sample ),
306
+ image_size = get_image_dimensions (input ),
321
307
padding = cast (List [int ], tuple (self .padding )),
322
308
pad_if_needed = self .pad_if_needed ,
323
309
fill = self .fill ,
324
310
padding_mode = self .padding_mode ,
325
311
)
326
- sample = features .Image .new_like (sample , output )
327
- elif isinstance (sample , PIL .Image .Image ):
328
- sample = F .random_pad_image_pil (
329
- sample ,
312
+ input = features .Image .new_like (input , output )
313
+ elif isinstance (input , PIL .Image .Image ):
314
+ input = F .random_pad_image_pil (
315
+ input ,
330
316
output_size = self .size ,
331
- image_size = get_image_dimensions (sample ),
332
- padding = cast ( List [ int ], tuple ( self .padding )) ,
317
+ image_size = get_image_dimensions (input ),
318
+ padding = self .padding ,
333
319
pad_if_needed = self .pad_if_needed ,
334
320
fill = self .fill ,
335
321
padding_mode = self .padding_mode ,
336
322
)
337
- elif is_simple_tensor (sample ):
338
- sample = F .random_pad_image_tensor (
339
- sample ,
323
+ elif is_simple_tensor (input ):
324
+ input = F .random_pad_image_tensor (
325
+ input ,
340
326
output_size = self .size ,
341
- image_size = get_image_dimensions (sample ),
342
- padding = cast ( List [ int ], tuple ( self .padding )) ,
327
+ image_size = get_image_dimensions (input ),
328
+ padding = self .padding ,
343
329
pad_if_needed = self .pad_if_needed ,
344
330
fill = self .fill , # TODO: should be converted to number
345
331
padding_mode = self .padding_mode ,
346
332
)
347
333
334
+ params .update (self ._get_crop_parameters (input ))
335
+
336
+ if isinstance (input , features .Image ):
337
+ output = F .crop_image_tensor (input , ** params )
338
+ return features .Image .new_like (input , output )
339
+ elif isinstance (input , PIL .Image .Image ):
340
+ return F .crop_image_pil (input , ** params )
341
+ elif is_simple_tensor (input ):
342
+ return F .crop_image_tensor (input , ** params )
343
+ else :
344
+ return input
345
+
346
+ def forward (self , * inputs : Any ) -> Any :
347
+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
348
+ if has_any (sample , features .BoundingBox , features .SegmentationMask ):
349
+ raise TypeError (f"BoundingBox'es and SegmentationMask's are not supported by { type (self ).__name__ } ()" )
350
+
348
351
return super ().forward (sample )
0 commit comments