3
3
from abc import ABC , abstractmethod
4
4
from glob import glob
5
5
from pathlib import Path
6
- from typing import Callable , List , Optional , Tuple , Union
6
+ from typing import Any , Callable , List , Optional , Tuple , Union
7
7
8
8
import numpy as np
9
9
import torch
10
10
from PIL import Image
11
11
12
12
from ..io .image import decode_png , read_file
13
+ from .folder import default_loader
13
14
from .utils import _read_pfm , verify_str_arg
14
15
from .vision import VisionDataset
15
16
@@ -32,19 +33,22 @@ class FlowDataset(ABC, VisionDataset):
32
33
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
33
34
_has_builtin_flow_mask = False
34
35
35
- def __init__ (self , root : Union [str , Path ], transforms : Optional [Callable ] = None ) -> None :
36
+ def __init__ (
37
+ self ,
38
+ root : Union [str , Path ],
39
+ transforms : Optional [Callable ] = None ,
40
+ loader : Callable [[str ], Any ] = default_loader ,
41
+ ) -> None :
36
42
37
43
super ().__init__ (root = root )
38
44
self .transforms = transforms
39
45
40
46
self ._flow_list : List [str ] = []
41
47
self ._image_list : List [List [str ]] = []
48
+ self ._loader = loader
42
49
43
- def _read_img (self , file_name : str ) -> Image .Image :
44
- img = Image .open (file_name )
45
- if img .mode != "RGB" :
46
- img = img .convert ("RGB" ) # type: ignore[assignment]
47
- return img
50
+ def _read_img (self , file_name : str ) -> Union [Image .Image , torch .Tensor ]:
51
+ return self ._loader (file_name )
48
52
49
53
@abstractmethod
50
54
def _read_flow (self , file_name : str ):
@@ -70,9 +74,9 @@ def __getitem__(self, index: int) -> Union[T1, T2]:
70
74
71
75
if self ._has_builtin_flow_mask or valid_flow_mask is not None :
72
76
# The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
73
- return img1 , img2 , flow , valid_flow_mask
77
+ return img1 , img2 , flow , valid_flow_mask # type: ignore[return-value]
74
78
else :
75
- return img1 , img2 , flow
79
+ return img1 , img2 , flow # type: ignore[return-value]
76
80
77
81
def __len__ (self ) -> int :
78
82
return len (self ._image_list )
@@ -120,6 +124,9 @@ class Sintel(FlowDataset):
120
124
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
121
125
``valid_flow_mask`` is expected for consistency with other datasets which
122
126
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
127
+ loader (callable, optional): A function to load an image given its path.
128
+ By default, it uses PIL as its image loader, but users could also pass in
129
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
123
130
"""
124
131
125
132
def __init__ (
@@ -128,8 +135,9 @@ def __init__(
128
135
split : str = "train" ,
129
136
pass_name : str = "clean" ,
130
137
transforms : Optional [Callable ] = None ,
138
+ loader : Callable [[str ], Any ] = default_loader ,
131
139
) -> None :
132
- super ().__init__ (root = root , transforms = transforms )
140
+ super ().__init__ (root = root , transforms = transforms , loader = loader )
133
141
134
142
verify_str_arg (split , "split" , valid_values = ("train" , "test" ))
135
143
verify_str_arg (pass_name , "pass_name" , valid_values = ("clean" , "final" , "both" ))
@@ -186,12 +194,21 @@ class KittiFlow(FlowDataset):
186
194
split (string, optional): The dataset split, either "train" (default) or "test"
187
195
transforms (callable, optional): A function/transform that takes in
188
196
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
197
+ loader (callable, optional): A function to load an image given its path.
198
+ By default, it uses PIL as its image loader, but users could also pass in
199
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
189
200
"""
190
201
191
202
_has_builtin_flow_mask = True
192
203
193
- def __init__ (self , root : Union [str , Path ], split : str = "train" , transforms : Optional [Callable ] = None ) -> None :
194
- super ().__init__ (root = root , transforms = transforms )
204
+ def __init__ (
205
+ self ,
206
+ root : Union [str , Path ],
207
+ split : str = "train" ,
208
+ transforms : Optional [Callable ] = None ,
209
+ loader : Callable [[str ], Any ] = default_loader ,
210
+ ) -> None :
211
+ super ().__init__ (root = root , transforms = transforms , loader = loader )
195
212
196
213
verify_str_arg (split , "split" , valid_values = ("train" , "test" ))
197
214
@@ -324,6 +341,9 @@ class FlyingThings3D(FlowDataset):
324
341
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
325
342
``valid_flow_mask`` is expected for consistency with other datasets which
326
343
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
344
+ loader (callable, optional): A function to load an image given its path.
345
+ By default, it uses PIL as its image loader, but users could also pass in
346
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
327
347
"""
328
348
329
349
def __init__ (
@@ -333,8 +353,9 @@ def __init__(
333
353
pass_name : str = "clean" ,
334
354
camera : str = "left" ,
335
355
transforms : Optional [Callable ] = None ,
356
+ loader : Callable [[str ], Any ] = default_loader ,
336
357
) -> None :
337
- super ().__init__ (root = root , transforms = transforms )
358
+ super ().__init__ (root = root , transforms = transforms , loader = loader )
338
359
339
360
verify_str_arg (split , "split" , valid_values = ("train" , "test" ))
340
361
split = split .upper ()
@@ -414,12 +435,21 @@ class HD1K(FlowDataset):
414
435
split (string, optional): The dataset split, either "train" (default) or "test"
415
436
transforms (callable, optional): A function/transform that takes in
416
437
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
438
+ loader (callable, optional): A function to load an image given its path.
439
+ By default, it uses PIL as its image loader, but users could also pass in
440
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
417
441
"""
418
442
419
443
_has_builtin_flow_mask = True
420
444
421
- def __init__ (self , root : Union [str , Path ], split : str = "train" , transforms : Optional [Callable ] = None ) -> None :
422
- super ().__init__ (root = root , transforms = transforms )
445
+ def __init__ (
446
+ self ,
447
+ root : Union [str , Path ],
448
+ split : str = "train" ,
449
+ transforms : Optional [Callable ] = None ,
450
+ loader : Callable [[str ], Any ] = default_loader ,
451
+ ) -> None :
452
+ super ().__init__ (root = root , transforms = transforms , loader = loader )
423
453
424
454
verify_str_arg (split , "split" , valid_values = ("train" , "test" ))
425
455
0 commit comments