@@ -4040,3 +4040,28 @@ def test_transform_params_correctness(self, side_range, make_input, device):
4040
4040
assert 0 <= padding [1 ] <= (side_range [1 ] - 1 ) * height
4041
4041
assert 0 <= padding [2 ] <= (side_range [1 ] - 1 ) * width
4042
4042
assert 0 <= padding [3 ] <= (side_range [1 ] - 1 ) * height
4043
+
4044
+
4045
+ class TestRandomPhotometricDistort :
4046
+ # Tests are light because this largely relies on the already tested
4047
+ # `adjust_{brightness,contrast,saturation,hue}` and `permute_channels` kernels.
4048
+
4049
+ @pytest .mark .parametrize (
4050
+ "make_input" ,
4051
+ [make_image_tensor , make_image_pil , make_image , make_video ],
4052
+ )
4053
+ @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .float32 ])
4054
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
4055
+ def test_transform (self , make_input , dtype , device ):
4056
+ if make_input is make_image_pil and not (dtype is torch .uint8 and device == "cpu" ):
4057
+ pytest .skip (
4058
+ "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
4059
+ "will degenerate to that anyway."
4060
+ )
4061
+
4062
+ check_transform (
4063
+ transforms .RandomPhotometricDistort (
4064
+ brightness = (0.3 , 0.4 ), contrast = (0.5 , 0.6 ), saturation = (0.7 , 0.8 ), hue = (- 0.1 , 0.2 ), p = 1
4065
+ ),
4066
+ make_input (dtype = dtype , device = device ),
4067
+ )
0 commit comments