Skip to content

Commit ef4718a

Browse files
authored
A temporary fix to windows unittests failing in INaturalistTestCase. (#9007)
1 parent 309bd7a commit ef4718a

File tree

3 files changed

+8
-11
lines changed

3 files changed

+8
-11
lines changed

test/common_utils.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,12 @@
3434

3535
@contextlib.contextmanager
3636
def get_tmp_dir(src=None, **kwargs):
37-
tmp_dir = tempfile.mkdtemp(**kwargs)
38-
if src is not None:
39-
os.rmdir(tmp_dir)
40-
shutil.copytree(src, tmp_dir)
41-
try:
37+
with tempfile.TemporaryDirectory(
38+
**kwargs,
39+
) as tmp_dir:
40+
if src is not None:
41+
shutil.copytree(src, tmp_dir)
4242
yield tmp_dir
43-
finally:
44-
shutil.rmtree(tmp_dir)
4543

4644

4745
def set_rng_seed(seed):

test/test_datasets.py

-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import re
1212
import shutil
1313
import string
14-
import sys
1514
import unittest
1615
import xml.etree.ElementTree as ET
1716
import zipfile
@@ -1904,7 +1903,6 @@ def test_class_to_idx(self):
19041903
assert dataset.class_to_idx == class_to_idx
19051904

19061905

1907-
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
19081906
class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
19091907
DATASET_CLASS = datasets.INaturalist
19101908
FEATURE_TYPES = (PIL.Image.Image, (int, tuple))

torchvision/datasets/inaturalist.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(
113113
for fname in files:
114114
self.index.append((dir_index, fname))
115115

116-
self.loader = loader or Image.open
116+
self.loader = loader
117117

118118
def _init_2021(self) -> None:
119119
"""Initialize based on 2021 layout"""
@@ -184,7 +184,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
184184
"""
185185

186186
cat_id, fname = self.index[index]
187-
img = self.loader(os.path.join(self.root, self.all_categories[cat_id], fname))
187+
image_path = os.path.join(self.root, self.all_categories[cat_id], fname)
188+
img = self.loader(image_path) if self.loader is not None else Image.open(image_path)
188189

189190
target: Any = []
190191
for t in self.target_type:

0 commit comments

Comments
 (0)