Skip to content

Commit 997384c

Browse files
authored
port tests for RandomPhotometricDistort (#7973)
1 parent ace9221 commit 997384c

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

test/test_transforms_v2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ class TestSmoke:
120120
(transforms.RandomEqualize(p=1.0), None),
121121
(transforms.RandomInvert(p=1.0), None),
122122
(transforms.RandomChannelPermutation(), None),
123-
(transforms.RandomPhotometricDistort(p=1.0), None),
124123
(transforms.RandomPosterize(bits=4, p=1.0), None),
125124
(transforms.RandomSolarize(threshold=0.5, p=1.0), None),
126125
(transforms.CenterCrop([16, 16]), None),

test/test_transforms_v2_refactored.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4040,3 +4040,28 @@ def test_transform_params_correctness(self, side_range, make_input, device):
40404040
assert 0 <= padding[1] <= (side_range[1] - 1) * height
40414041
assert 0 <= padding[2] <= (side_range[1] - 1) * width
40424042
assert 0 <= padding[3] <= (side_range[1] - 1) * height
4043+
4044+
4045+
class TestRandomPhotometricDistort:
4046+
# Tests are light because this largely relies on the already tested
4047+
# `adjust_{brightness,contrast,saturation,hue}` and `permute_channels` kernels.
4048+
4049+
@pytest.mark.parametrize(
4050+
"make_input",
4051+
[make_image_tensor, make_image_pil, make_image, make_video],
4052+
)
4053+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
4054+
@pytest.mark.parametrize("device", cpu_and_cuda())
4055+
def test_transform(self, make_input, dtype, device):
4056+
if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"):
4057+
pytest.skip(
4058+
"PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
4059+
"will degenerate to that anyway."
4060+
)
4061+
4062+
check_transform(
4063+
transforms.RandomPhotometricDistort(
4064+
brightness=(0.3, 0.4), contrast=(0.5, 0.6), saturation=(0.7, 0.8), hue=(-0.1, 0.2), p=1
4065+
),
4066+
make_input(dtype=dtype, device=device),
4067+
)

0 commit comments

Comments
 (0)