Skip to content

Commit 956025b

Browse files
feat: expose loader parameter in FlowDataset type, except `Flying… (#8972)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 997348d commit 956025b

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

test/test_datasets.py

+8
Original file line numberDiff line numberDiff line change
@@ -2038,6 +2038,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
20382038

20392039
FLOW_H, FLOW_W = 3, 4
20402040

2041+
SUPPORT_TV_IMAGE_DECODE = True
2042+
20412043
def inject_fake_data(self, tmpdir, config):
20422044
root = pathlib.Path(tmpdir) / "Sintel"
20432045

@@ -2104,6 +2106,8 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
21042106
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
21052107
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
21062108

2109+
SUPPORT_TV_IMAGE_DECODE = True
2110+
21072111
def inject_fake_data(self, tmpdir, config):
21082112
root = pathlib.Path(tmpdir) / "KittiFlow"
21092113

@@ -2223,6 +2227,8 @@ class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
22232227

22242228
FLOW_H, FLOW_W = 3, 4
22252229

2230+
SUPPORT_TV_IMAGE_DECODE = True
2231+
22262232
def inject_fake_data(self, tmpdir, config):
22272233
root = pathlib.Path(tmpdir) / "FlyingThings3D"
22282234

@@ -2289,6 +2295,8 @@ def test_bad_input(self):
22892295
class HD1KTestCase(KittiFlowTestCase):
22902296
DATASET_CLASS = datasets.HD1K
22912297

2298+
SUPPORT_TV_IMAGE_DECODE = True
2299+
22922300
def inject_fake_data(self, tmpdir, config):
22932301
root = pathlib.Path(tmpdir) / "hd1k"
22942302

torchvision/datasets/_optical_flow.py

+45-15
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from abc import ABC, abstractmethod
44
from glob import glob
55
from pathlib import Path
6-
from typing import Callable, List, Optional, Tuple, Union
6+
from typing import Any, Callable, List, Optional, Tuple, Union
77

88
import numpy as np
99
import torch
1010
from PIL import Image
1111

1212
from ..io.image import decode_png, read_file
13+
from .folder import default_loader
1314
from .utils import _read_pfm, verify_str_arg
1415
from .vision import VisionDataset
1516

@@ -32,19 +33,22 @@ class FlowDataset(ABC, VisionDataset):
3233
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
3334
_has_builtin_flow_mask = False
3435

35-
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
36+
def __init__(
37+
self,
38+
root: Union[str, Path],
39+
transforms: Optional[Callable] = None,
40+
loader: Callable[[str], Any] = default_loader,
41+
) -> None:
3642

3743
super().__init__(root=root)
3844
self.transforms = transforms
3945

4046
self._flow_list: List[str] = []
4147
self._image_list: List[List[str]] = []
48+
self._loader = loader
4249

43-
def _read_img(self, file_name: str) -> Image.Image:
44-
img = Image.open(file_name)
45-
if img.mode != "RGB":
46-
img = img.convert("RGB") # type: ignore[assignment]
47-
return img
50+
def _read_img(self, file_name: str) -> Union[Image.Image, torch.Tensor]:
51+
return self._loader(file_name)
4852

4953
@abstractmethod
5054
def _read_flow(self, file_name: str):
@@ -70,9 +74,9 @@ def __getitem__(self, index: int) -> Union[T1, T2]:
7074

7175
if self._has_builtin_flow_mask or valid_flow_mask is not None:
7276
# The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
73-
return img1, img2, flow, valid_flow_mask
77+
return img1, img2, flow, valid_flow_mask # type: ignore[return-value]
7478
else:
75-
return img1, img2, flow
79+
return img1, img2, flow # type: ignore[return-value]
7680

7781
def __len__(self) -> int:
7882
return len(self._image_list)
@@ -120,6 +124,9 @@ class Sintel(FlowDataset):
120124
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
121125
``valid_flow_mask`` is expected for consistency with other datasets which
122126
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
127+
loader (callable, optional): A function to load an image given its path.
128+
By default, it uses PIL as its image loader, but users could also pass in
129+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
123130
"""
124131

125132
def __init__(
@@ -128,8 +135,9 @@ def __init__(
128135
split: str = "train",
129136
pass_name: str = "clean",
130137
transforms: Optional[Callable] = None,
138+
loader: Callable[[str], Any] = default_loader,
131139
) -> None:
132-
super().__init__(root=root, transforms=transforms)
140+
super().__init__(root=root, transforms=transforms, loader=loader)
133141

134142
verify_str_arg(split, "split", valid_values=("train", "test"))
135143
verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
@@ -186,12 +194,21 @@ class KittiFlow(FlowDataset):
186194
split (string, optional): The dataset split, either "train" (default) or "test"
187195
transforms (callable, optional): A function/transform that takes in
188196
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
197+
loader (callable, optional): A function to load an image given its path.
198+
By default, it uses PIL as its image loader, but users could also pass in
199+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
189200
"""
190201

191202
_has_builtin_flow_mask = True
192203

193-
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
194-
super().__init__(root=root, transforms=transforms)
204+
def __init__(
205+
self,
206+
root: Union[str, Path],
207+
split: str = "train",
208+
transforms: Optional[Callable] = None,
209+
loader: Callable[[str], Any] = default_loader,
210+
) -> None:
211+
super().__init__(root=root, transforms=transforms, loader=loader)
195212

196213
verify_str_arg(split, "split", valid_values=("train", "test"))
197214

@@ -324,6 +341,9 @@ class FlyingThings3D(FlowDataset):
324341
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
325342
``valid_flow_mask`` is expected for consistency with other datasets which
326343
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
344+
loader (callable, optional): A function to load an image given its path.
345+
By default, it uses PIL as its image loader, but users could also pass in
346+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
327347
"""
328348

329349
def __init__(
@@ -333,8 +353,9 @@ def __init__(
333353
pass_name: str = "clean",
334354
camera: str = "left",
335355
transforms: Optional[Callable] = None,
356+
loader: Callable[[str], Any] = default_loader,
336357
) -> None:
337-
super().__init__(root=root, transforms=transforms)
358+
super().__init__(root=root, transforms=transforms, loader=loader)
338359

339360
verify_str_arg(split, "split", valid_values=("train", "test"))
340361
split = split.upper()
@@ -414,12 +435,21 @@ class HD1K(FlowDataset):
414435
split (string, optional): The dataset split, either "train" (default) or "test"
415436
transforms (callable, optional): A function/transform that takes in
416437
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
438+
loader (callable, optional): A function to load an image given its path.
439+
By default, it uses PIL as its image loader, but users could also pass in
440+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
417441
"""
418442

419443
_has_builtin_flow_mask = True
420444

421-
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
422-
super().__init__(root=root, transforms=transforms)
445+
def __init__(
446+
self,
447+
root: Union[str, Path],
448+
split: str = "train",
449+
transforms: Optional[Callable] = None,
450+
loader: Callable[[str], Any] = default_loader,
451+
) -> None:
452+
super().__init__(root=root, transforms=transforms, loader=loader)
423453

424454
verify_str_arg(split, "split", valid_values=("train", "test"))
425455

0 commit comments

Comments
 (0)