17
17
18
18
import abc
19
19
import collections
20
+ from collections .abc import Iterator
20
21
import contextlib
21
22
import functools
22
23
import importlib
23
24
import inspect
24
25
import os .path
25
- from typing import ClassVar , Dict , Iterator , List , Type , Text , Tuple
26
+ import time
27
+ from typing import ClassVar , Type
26
28
29
+ from absl import logging
27
30
from etils import epath
28
31
from tensorflow_datasets .core import constants
29
32
from tensorflow_datasets .core import naming
30
33
from tensorflow_datasets .core import visibility
34
+ import tensorflow_datasets .core .logging as _tfds_logging
35
+ from tensorflow_datasets .core .logging import call_metadata as _call_metadata
31
36
from tensorflow_datasets .core .utils import py_utils
32
37
from tensorflow_datasets .core .utils import resource_utils
33
38
38
43
# <str snake_cased_name, abstract DatasetBuilder subclass>
39
44
_ABSTRACT_DATASET_REGISTRY = {}
40
45
41
- # Keep track of Dict [str (module name), List [DatasetBuilder]]
46
+ # Keep track of dict [str (module name), list [DatasetBuilder]]
42
47
# This is directly accessed by `tfds.community.builder_cls_from_module` when
43
48
# importing community packages.
44
49
_MODULE_TO_DATASETS = collections .defaultdict (list )
51
56
# <str snake_cased_name, abstract DatasetCollectionBuilder subclass>
52
57
_ABSTRACT_DATASET_COLLECTION_REGISTRY = {}
53
58
54
- # Keep track of Dict [str (module name), List [DatasetCollectionBuilder]]
59
+ # Keep track of dict [str (module name), list [DatasetCollectionBuilder]]
55
60
_MODULE_TO_DATASET_COLLECTIONS = collections .defaultdict (list )
56
61
57
62
# eg for dataset "foo": "tensorflow_datasets.datasets.foo.foo_dataset_builder".
@@ -80,6 +85,70 @@ def skip_registration() -> Iterator[None]:
80
85
_skip_registration = False
81
86
82
87
88
@functools.cache
def _import_legacy_builders() -> None:
  """Imports legacy builders.

  Dataset builder modules are imported lazily — and only once, thanks to
  `functools.cache` — so that merely importing TFDS does not pay the cost of
  importing every dataset package. Import timing and errors are reported via
  `_tfds_logging.tfds_import`.
  """
  modules_to_import = [
      'audio',
      'graphs',
      'image',
      'image_classification',
      'object_detection',
      'nearest_neighbors',
      'question_answering',
      'd4rl',
      'ranking',
      'recommendation',
      'rl_unplugged',
      'rlds.datasets',
      'robotics',
      'robomimic',
      'structured',
      'summarization',
      'text',
      'text_simplification',
      'time_series',
      'translate',
      'video',
      'vision_language',
  ]

  before_dataset_imports = time.time()
  metadata = _call_metadata.CallMetadata()
  metadata.start_time_micros = int(before_dataset_imports * 1e6)
  try:
    # For builds that don't include all dataset builders, we don't want to
    # fail on import errors of dataset builders. The try/except is placed
    # inside the loop so that one missing module does not prevent the
    # remaining modules from being imported (wrapping the whole loop would
    # silently skip everything after the first failure).
    for module in modules_to_import:
      try:
        importlib.import_module(f'tensorflow_datasets.{module}')
      except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so this catches
        # both missing packages and import-time failures of optional deps.
        pass
  except Exception as exception:  # pylint: disable=broad-except
    metadata.mark_error()
    logging.exception(exception)
  finally:
    import_time_ms_dataset_builders = int(
        (time.time() - before_dataset_imports) * 1000
    )
    metadata.mark_end()
    _tfds_logging.tfds_import(
        metadata=metadata,
        import_time_ms_tensorflow=0,
        import_time_ms_dataset_builders=import_time_ms_dataset_builders,
    )
142
+
143
+ @functools .cache
144
+ def _import_dataset_collections () -> None :
145
+ """Imports dataset collections."""
146
+ try :
147
+ importlib .import_module ('tensorflow_datasets.dataset_collections' )
148
+ except (ImportError , ModuleNotFoundError ):
149
+ pass
150
+
151
+
83
152
# The implementation of this class follows closely RegisteredDataset.
84
153
class RegisteredDatasetCollection (abc .ABC ):
85
154
"""Subclasses will be registered and given a `name` property."""
@@ -129,23 +198,24 @@ def __init_subclass__(cls, skip_registration=False, **kwargs): # pylint: disabl
129
198
_DATASET_COLLECTION_REGISTRY [cls .name ] = cls
130
199
131
200
132
def list_imported_dataset_collections() -> list[str]:
  """Returns the string names of all `tfds.core.DatasetCollection`s."""
  # Give the collection modules a chance to register themselves first.
  _import_dataset_collections()
  return sorted(_DATASET_COLLECTION_REGISTRY)
139
206
140
207
141
208
def is_dataset_collection(name: str) -> bool:
  """Returns whether `name` refers to a registered `DatasetCollection`."""
  # Collections register themselves as a side effect of being imported.
  _import_dataset_collections()
  registered = name in _DATASET_COLLECTION_REGISTRY
  return registered
143
211
144
212
145
213
def imported_dataset_collection_cls (
146
214
name : str ,
147
215
) -> Type [RegisteredDatasetCollection ]:
148
216
"""Returns the Registered dataset class."""
217
+ _import_dataset_collections ()
218
+
149
219
if name in _ABSTRACT_DATASET_COLLECTION_REGISTRY :
150
220
raise AssertionError (f'DatasetCollection { name } is an abstract class.' )
151
221
@@ -224,8 +294,9 @@ def _is_builder_available(builder_cls: Type[RegisteredDataset]) -> bool:
224
294
return visibility .DatasetType .TFDS_PUBLIC .is_available ()
225
295
226
296
227
- def list_imported_builders () -> List [str ]:
297
+ def list_imported_builders () -> list [str ]:
228
298
"""Returns the string names of all `tfds.core.DatasetBuilder`s."""
299
+ _import_legacy_builders ()
229
300
all_builders = [
230
301
builder_name
231
302
for builder_name , builder_cls in _DATASET_REGISTRY .items ()
@@ -236,8 +307,8 @@ def list_imported_builders() -> List[str]:
236
307
237
308
@functools .lru_cache (maxsize = None )
238
309
def _get_existing_dataset_packages (
239
- datasets_dir : Text ,
240
- ) -> Dict [ Text , Tuple [epath .Path , Text ]]:
310
+ datasets_dir : str ,
311
+ ) -> dict [ str , tuple [epath .Path , str ]]:
241
312
"""Returns existing datasets.
242
313
243
314
Args:
@@ -293,7 +364,12 @@ def imported_builder_cls(name: str) -> Type[RegisteredDataset]:
293
364
raise AssertionError (f'Dataset { name } is an abstract class.' )
294
365
295
366
if name not in _DATASET_REGISTRY :
296
- raise DatasetNotFoundError (f'Dataset { name } not found.' )
367
+ # Dataset not found in the registry, try to import legacy builders.
368
+ # Dataset builders are imported lazily to avoid slowing down the startup
369
+ # of the binary.
370
+ _import_legacy_builders ()
371
+ if name not in _DATASET_REGISTRY :
372
+ raise DatasetNotFoundError (f'Dataset { name } not found.' )
297
373
298
374
builder_cls = _DATASET_REGISTRY [name ]
299
375
if not _is_builder_available (builder_cls ):
0 commit comments