@@ -1512,8 +1512,7 @@ def test_transform(self, input_type, device):
1512
1512
@pytest .mark .parametrize (
1513
1513
"interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
1514
1514
)
1515
- # TODO: investigate why expand=True leads to different shapes between PIL and tensor
1516
- @pytest .mark .parametrize ("expand" , [False ])
1515
+ @pytest .mark .parametrize ("expand" , [False , True ])
1517
1516
@pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
1518
1517
def test_functional_image_correctness (self , angle , center , interpolation , expand , fill ):
1519
1518
image = make_input (torch .Tensor , dtype = torch .uint8 , device = "cpu" )
@@ -1534,8 +1533,7 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand
1534
1533
@pytest .mark .parametrize (
1535
1534
"interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
1536
1535
)
1537
- # TODO: investigate why expand=True leads to different shapes between PIL and tensor
1538
- @pytest .mark .parametrize ("expand" , [False ])
1536
+ @pytest .mark .parametrize ("expand" , [False , True ])
1539
1537
@pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
1540
1538
@pytest .mark .parametrize ("seed" , list (range (5 )))
1541
1539
def test_transform_image_correctness (self , center , interpolation , expand , fill , seed ):
0 commit comments