1
1
import collections .abc
2
2
import math
3
3
import warnings
4
- from typing import Any , Dict , List , Union , Sequence , Tuple , cast
4
+ from typing import Any , Dict , List , Union , Sequence , Tuple , cast , Literal
5
5
6
6
import PIL .Image
7
7
import torch
@@ -259,7 +259,14 @@ def apply_recursively(obj: Any) -> Any:
259
259
260
260
261
261
class RandomCrop(Transform):
    """Crop the input at a random location, optionally padding it first.

    Args:
        size: Target output size ``(h, w)``; a single int is expanded to a
            square crop by ``_setup_size``.
        padding: Optional padding applied to every sample before cropping.
        pad_if_needed: Pad inputs smaller than ``size`` so a crop is possible.
        fill: Fill value used for constant padding.
        padding_mode: One of the four padding strategies accepted by the
            functional pad kernels.
    """

    def __init__(
        self,
        # BUG FIX: original read ``size = Union[int, Sequence[int]]`` — an
        # assignment, not an annotation — which silently made the typing
        # object itself the parameter's default value.
        size: Union[int, Sequence[int]],
        # BUG FIX: ``None`` default requires an optional annotation.
        padding: Union[Sequence[int], None] = None,
        pad_if_needed: bool = False,
        fill: Union[int, str, Sequence[int]] = 0,
        # ``Union[str, Literal[...]]`` collapses to plain ``str`` for type
        # checkers; the Literal alone documents (and enforces) the modes.
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        # NOTE(review): the diff hunk elides the next three assignments;
        # reconstructed from the attributes forward() reads — confirm against
        # the full file.
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode
270
277
271
278
def _get_params (self , sample : Any ) -> Dict [str , Any ]:
272
-
273
279
"""Get parameters for ``crop`` for a random crop.
274
-
275
280
Args:
276
281
sample (PIL Image, Tensor or features.Image): Image to be cropped.
277
-
278
282
Returns:
279
283
dict: Dict containing 'top', 'left', 'height', and 'width'
280
284
"""
@@ -294,19 +298,51 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
294
298
return dict (top = i , left = j , height = th , width = tw )
295
299
296
300
def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
297
-
298
301
if isinstance (input , features .Image ):
299
- output = F .random_crop_image_tensor (input , ** params , padding = self . padding )
300
- input = features .Image .new_like (input , output )
302
+ output = F .crop_image_tensor (input , ** params )
303
+ return features .Image .new_like (input , output )
301
304
elif isinstance (input , PIL .Image .Image ):
302
- input = F .random_crop_image_pil (input , ** params )
305
+ return F .crop_image_pil (input , ** params )
306
+ elif is_simple_tensor (input ):
307
+ return F .crop_image_tensor (input , ** params )
303
308
else :
304
- input = F .random_crop_image_tensor (input , ** params )
305
-
306
- return input
309
+ return input
307
310
308
311
def forward (self , * inputs : Any ) -> Any :
309
312
sample = inputs if len (inputs ) > 1 else inputs [0 ]
310
313
if has_any (sample , features .BoundingBox , features .SegmentationMask ):
311
314
raise TypeError (f"BoundingBox'es and SegmentationMask's are not supported by { type (self ).__name__ } ()" )
315
+
316
+ if isinstance (sample , features .Image ):
317
+ output = F .random_pad_image_tensor (
318
+ sample ,
319
+ output_size = self .size ,
320
+ image_size = get_image_dimensions (sample ),
321
+ padding = self .padding ,
322
+ pad_if_needed = self .pad_if_needed ,
323
+ fill = self .fill ,
324
+ padding_mode = self .padding_mode ,
325
+ )
326
+ sample = features .Image .new_like (sample , output )
327
+ elif isinstance (sample , PIL .Image .Image ):
328
+ sample = F .random_pad_image_pil (
329
+ sample ,
330
+ output_size = self .size ,
331
+ image_size = get_image_dimensions (sample ),
332
+ padding = self .padding ,
333
+ pad_if_needed = self .pad_if_needed ,
334
+ fill = self .fill ,
335
+ padding_mode = self .padding_mode ,
336
+ )
337
+ elif is_simple_tensor (sample ):
338
+ sample = F .random_pad_image_tensor (
339
+ sample ,
340
+ output_size = self .size ,
341
+ image_size = get_image_dimensions (sample ),
342
+ padding = self .padding ,
343
+ pad_if_needed = self .pad_if_needed ,
344
+ fill = self .fill ,
345
+ padding_mode = self .padding_mode ,
346
+ )
347
+
312
348
return super ().forward (sample )
0 commit comments