Skip to content

Commit 309bd7a

Browse files
feat: add loader to Omniglot and INaturalist's argument. (#8945)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 956025b commit 309bd7a

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

test/test_datasets.py

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import re
1212
import shutil
1313
import string
14+
import sys
1415
import unittest
1516
import xml.etree.ElementTree as ET
1617
import zipfile
@@ -1146,6 +1147,7 @@ class OmniglotTestCase(datasets_utils.ImageDatasetTestCase):
11461147
DATASET_CLASS = datasets.Omniglot
11471148

11481149
ADDITIONAL_CONFIGS = combinations_grid(background=(True, False))
1150+
SUPPORT_TV_IMAGE_DECODE = True
11491151

11501152
def inject_fake_data(self, tmpdir, config):
11511153
target_folder = (
@@ -1902,6 +1904,7 @@ def test_class_to_idx(self):
19021904
assert dataset.class_to_idx == class_to_idx
19031905

19041906

1907+
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
19051908
class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
19061909
DATASET_CLASS = datasets.INaturalist
19071910
FEATURE_TYPES = (PIL.Image.Image, (int, tuple))
@@ -1910,6 +1913,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
19101913
target_type=("kingdom", "full", "genus", ["kingdom", "phylum", "class", "order", "family", "genus", "full"]),
19111914
version=("2021_train",),
19121915
)
1916+
SUPPORT_TV_IMAGE_DECODE = True
19131917

19141918
def inject_fake_data(self, tmpdir, config):
19151919
categories = [

torchvision/datasets/inaturalist.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ class INaturalist(VisionDataset):
6262
download (bool, optional): If true, downloads the dataset from the internet and
6363
puts it in root directory. If dataset is already downloaded, it is not
6464
downloaded again.
65+
loader (callable, optional): A function to load an image given its path.
66+
By default, it uses PIL as its image loader, but users could also pass in
67+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
6568
"""
6669

6770
def __init__(
@@ -72,6 +75,7 @@ def __init__(
7275
transform: Optional[Callable] = None,
7376
target_transform: Optional[Callable] = None,
7477
download: bool = False,
78+
loader: Optional[Callable[[Union[str, Path]], Any]] = None,
7579
) -> None:
7680
self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
7781

@@ -109,6 +113,8 @@ def __init__(
109113
for fname in files:
110114
self.index.append((dir_index, fname))
111115

116+
self.loader = loader or Image.open
117+
112118
def _init_2021(self) -> None:
113119
"""Initialize based on 2021 layout"""
114120

@@ -178,7 +184,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
178184
"""
179185

180186
cat_id, fname = self.index[index]
181-
img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
187+
img = self.loader(os.path.join(self.root, self.all_categories[cat_id], fname))
182188

183189
target: Any = []
184190
for t in self.target_type:

torchvision/datasets/omniglot.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ class Omniglot(VisionDataset):
2323
download (bool, optional): If true, downloads the dataset zip files from the internet and
2424
puts it in root directory. If the zip files are already downloaded, they are not
2525
downloaded again.
26+
loader (callable, optional): A function to load an image given its path.
27+
By default, it uses PIL as its image loader, but users could also pass in
28+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
2629
"""
2730

2831
folder = "omniglot-py"
@@ -39,6 +42,7 @@ def __init__(
3942
transform: Optional[Callable] = None,
4043
target_transform: Optional[Callable] = None,
4144
download: bool = False,
45+
loader: Optional[Callable[[Union[str, Path]], Any]] = None,
4246
) -> None:
4347
super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
4448
self.background = background
@@ -59,6 +63,7 @@ def __init__(
5963
for idx, character in enumerate(self._characters)
6064
]
6165
self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
66+
self.loader = loader
6267

6368
def __len__(self) -> int:
6469
return len(self._flat_character_images)
@@ -73,7 +78,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
7378
"""
7479
image_name, character_class = self._flat_character_images[index]
7580
image_path = join(self.target_folder, self._characters[character_class], image_name)
76-
image = Image.open(image_path, mode="r").convert("L")
81+
image = Image.open(image_path, mode="r").convert("L") if self.loader is None else self.loader(image_path)
7782

7883
if self.transform:
7984
image = self.transform(image)

0 commit comments

Comments
 (0)