Skip to content

Commit 1d6a259

Browse files
authored
Fixed error condition in RandomCrop (#6548)
1 parent 112accf commit 1d6a259

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1670,7 +1670,7 @@ def test_random_crop():
16701670
assert result.size(1) == height + 1
16711671
assert result.size(2) == width + 1
16721672

1673-
t = transforms.RandomCrop(48)
1673+
t = transforms.RandomCrop(33)
16741674
img = torch.ones(3, 32, 32)
16751675
with pytest.raises(ValueError, match=r"Required crop size .+ is larger than input image size .+"):
16761676
t(img)

torchvision/prototype/transforms/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
443443
if height < output_height:
444444
height += 2 * (output_height - height)
445445

446-
if height + 1 < output_height or width + 1 < output_width:
446+
if height < output_height or width < output_width:
447447
raise ValueError(
448448
f"Required crop size {(output_height, output_width)} is larger then input image size {(height, width)}"
449449
)

torchvision/transforms/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int
628628
_, h, w = F.get_dimensions(img)
629629
th, tw = output_size
630630

631-
if h + 1 < th or w + 1 < tw:
631+
if h < th or w < tw:
632632
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
633633

634634
if w == tw and h == th:

0 commit comments

Comments
 (0)