
Commit 4fd9a5d

jblespiau authored and The TensorFlow Datasets Authors committed
Lazily import legacy dataset builders
PiperOrigin-RevId: 638251592
1 parent ea633bb commit 4fd9a5d

4 files changed: +72 −96 lines changed

tensorflow_datasets/__init__.py

Lines changed: 56 additions & 3 deletions
@@ -30,17 +30,62 @@

 * These API docs
 * [Available datasets](https://www.tensorflow.org/datasets/catalog/overview)
-* [Colab
-  tutorial](https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb)
+* [Colab tutorial](https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb)
 * [Add a dataset](https://www.tensorflow.org/datasets/add_dataset)
 """
 # pylint: enable=line-too-long
 # pylint: disable=g-import-not-at-top,g-bad-import-order,wrong-import-position,unused-import

+import time
+
+_TIMESTAMP_IMPORT_STARTS = time.time()
 from absl import logging
+import tensorflow_datasets.core.logging as _tfds_logging
+from tensorflow_datasets.core.logging import call_metadata as _call_metadata
+
+_metadata = _call_metadata.CallMetadata()
+_metadata.start_time_micros = int(_TIMESTAMP_IMPORT_STARTS * 1e6)
+_import_time_ms_dataset_builders = 0

 try:
-  from tensorflow_datasets import rlds  # pylint: disable=g-bad-import-order
+  # Imports for registration
+  _before_dataset_imports = time.time()
+  from tensorflow_datasets import dataset_collections
+
+  # pytype: disable=import-error
+  # For builds that don't include all dataset builders, we don't want to fail on
+  # import errors of dataset builders.
+  try:
+    from tensorflow_datasets import audio
+    from tensorflow_datasets import graphs
+    from tensorflow_datasets import image
+    from tensorflow_datasets import image_classification
+    from tensorflow_datasets import object_detection
+    from tensorflow_datasets import nearest_neighbors
+    from tensorflow_datasets import question_answering
+    from tensorflow_datasets import d4rl
+    from tensorflow_datasets import ranking
+    from tensorflow_datasets import recommendation
+    from tensorflow_datasets import rl_unplugged
+    from tensorflow_datasets import rlds
+    from tensorflow_datasets import robotics
+    from tensorflow_datasets import robomimic
+    from tensorflow_datasets import structured
+    from tensorflow_datasets import summarization
+    from tensorflow_datasets import text
+    from tensorflow_datasets import text_simplification
+    from tensorflow_datasets import time_series
+    from tensorflow_datasets import translate
+    from tensorflow_datasets import video
+    from tensorflow_datasets import vision_language
+
+  except ImportError:
+    pass
+  # pytype: enable=import-error
+
+  _import_time_ms_dataset_builders = int(
+      (time.time() - _before_dataset_imports) * 1000
+  )

   # Public API to create and generate a dataset
   from tensorflow_datasets.public_api import *  # pylint: disable=wildcard-import
@@ -49,4 +94,12 @@
   __all__ = public_api.__all__

 except Exception as exception:  # pylint: disable=broad-except
+  _metadata.mark_error()
   logging.exception(exception)
+finally:
+  _metadata.mark_end()
+  _tfds_logging.tfds_import(
+      metadata=_metadata,
+      import_time_ms_tensorflow=0,
+      import_time_ms_dataset_builders=_import_time_ms_dataset_builders,
+  )
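For orientation, the pattern these hunks add to `__init__.py` is "time the builder imports and always report the result, even when an import fails". Below is a minimal standalone sketch of that pattern, under the assumption that a generic reporting callback is enough for illustration; `report_import` is hypothetical and stands in for the real `_call_metadata` / `_tfds_logging.tfds_import` machinery.

import importlib
import time


def report_import(duration_ms: int, failed: bool) -> None:
  """Hypothetical stand-in for CallMetadata + tfds_logging.tfds_import."""
  print(f"builder imports took {duration_ms} ms (failed={failed})")


def timed_import(module_names) -> None:
  """Best-effort imports that always report how long they took."""
  start = time.time()
  failed = False
  try:
    for name in module_names:
      try:
        importlib.import_module(name)
      except ImportError:
        # Builds that ship without some builders should not fail here.
        pass
  except Exception:  # Anything unexpected is recorded, then re-raised.
    failed = True
    raise
  finally:
    report_import(int((time.time() - start) * 1000), failed)


timed_import(["json", "no_such_module", "csv"])  # prints the timing line

The `finally:` block is what guarantees the timing is reported whether the imports succeed, are skipped on `ImportError`, or raise something unexpected.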

tensorflow_datasets/core/lazy_builder_import.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 defined under the "datasets/" directory and the legacy locations have been
 deprecated (ie: a new major release).
 """
+from typing import Text

 from absl import logging
 from tensorflow_datasets.core import registered
@@ -27,7 +28,7 @@
 class LazyBuilderImport:
   """Lazy load DatasetBuilder from given name from legacy locations."""

-  def __init__(self, dataset_name: str):
+  def __init__(self, dataset_name: Text):
     object.__setattr__(self, "_dataset_name", dataset_name)
     object.__setattr__(self, "_dataset_cls", None)
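Only `__init__` of `LazyBuilderImport` is touched here; the rest of the class is not part of this diff. As rough orientation, a lazy proxy of this kind typically defers the expensive lookup until an attribute is first accessed. The sketch below is generic, not code from this commit; the `loader` callback and `_resolve` helper are assumptions used for illustration.

class LazyProxy:
  """Defers creating the wrapped object until an attribute is first accessed."""

  def __init__(self, name, loader):
    # Mirrors the diff: object.__setattr__ bypasses any overridden __setattr__.
    object.__setattr__(self, "_name", name)
    object.__setattr__(self, "_target", None)
    object.__setattr__(self, "_loader", loader)

  def _resolve(self):
    if self._target is None:
      object.__setattr__(self, "_target", self._loader(self._name))
    return self._target

  def __getattr__(self, attr):
    # Only reached for attributes not found on the proxy itself.
    return getattr(self._resolve(), attr)


# Usage: the "builder" (a plain string here, for the demo) is only created
# when an attribute is first used.
lazy = LazyProxy("mnist", loader=lambda name: name.upper())
print(lazy.lower())  # resolves now, prints "mnist"

Using `object.__setattr__` in the constructor, as the real class does, keeps the stored fields from going through any custom attribute hooks the proxy defines.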

tensorflow_datasets/core/registered.py

Lines changed: 12 additions & 88 deletions
@@ -17,22 +17,17 @@

 import abc
 import collections
-from collections.abc import Iterator
 import contextlib
 import functools
 import importlib
 import inspect
 import os.path
-import time
-from typing import ClassVar, Type
+from typing import ClassVar, Dict, Iterator, List, Type, Text, Tuple

-from absl import logging
 from etils import epath
 from tensorflow_datasets.core import constants
 from tensorflow_datasets.core import naming
 from tensorflow_datasets.core import visibility
-import tensorflow_datasets.core.logging as _tfds_logging
-from tensorflow_datasets.core.logging import call_metadata as _call_metadata
 from tensorflow_datasets.core.utils import py_utils
 from tensorflow_datasets.core.utils import resource_utils

@@ -43,7 +38,7 @@
 # <str snake_cased_name, abstract DatasetBuilder subclass>
 _ABSTRACT_DATASET_REGISTRY = {}

-# Keep track of dict[str (module name), list[DatasetBuilder]]
+# Keep track of Dict[str (module name), List[DatasetBuilder]]
 # This is directly accessed by `tfds.community.builder_cls_from_module` when
 # importing community packages.
 _MODULE_TO_DATASETS = collections.defaultdict(list)
@@ -56,7 +51,7 @@
 # <str snake_cased_name, abstract DatasetCollectionBuilder subclass>
 _ABSTRACT_DATASET_COLLECTION_REGISTRY = {}

-# Keep track of dict[str (module name), list[DatasetCollectionBuilder]]
+# Keep track of Dict[str (module name), List[DatasetCollectionBuilder]]
 _MODULE_TO_DATASET_COLLECTIONS = collections.defaultdict(list)

 # eg for dataset "foo": "tensorflow_datasets.datasets.foo.foo_dataset_builder".
@@ -85,70 +80,6 @@ def skip_registration() -> Iterator[None]:
   _skip_registration = False


-@functools.cache
-def _import_legacy_builders() -> None:
-  """Imports legacy builders."""
-  modules_to_import = [
-      'audio',
-      'graphs',
-      'image',
-      'image_classification',
-      'object_detection',
-      'nearest_neighbors',
-      'question_answering',
-      'd4rl',
-      'ranking',
-      'recommendation',
-      'rl_unplugged',
-      'rlds.datasets',
-      'robotics',
-      'robomimic',
-      'structured',
-      'summarization',
-      'text',
-      'text_simplification',
-      'time_series',
-      'translate',
-      'video',
-      'vision_language',
-  ]
-
-  before_dataset_imports = time.time()
-  metadata = _call_metadata.CallMetadata()
-  metadata.start_time_micros = int(before_dataset_imports * 1e6)
-  try:
-    # For builds that don't include all dataset builders, we don't want to fail
-    # on import errors of dataset builders.
-    try:
-      for module in modules_to_import:
-        importlib.import_module(f'tensorflow_datasets.{module}')
-    except (ImportError, ModuleNotFoundError):
-      pass
-
-  except Exception as exception:  # pylint: disable=broad-except
-    metadata.mark_error()
-    logging.exception(exception)
-  finally:
-    import_time_ms_dataset_builders = int(
-        (time.time() - before_dataset_imports) * 1000
-    )
-    metadata.mark_end()
-    _tfds_logging.tfds_import(
-        metadata=metadata,
-        import_time_ms_tensorflow=0,
-        import_time_ms_dataset_builders=import_time_ms_dataset_builders,
-    )
-
-
-@functools.cache
-def _import_dataset_collections() -> None:
-  """Imports dataset collections."""
-  try:
-    importlib.import_module('tensorflow_datasets.dataset_collections')
-  except (ImportError, ModuleNotFoundError):
-    pass
-
-
 # The implementation of this class follows closely RegisteredDataset.
 class RegisteredDatasetCollection(abc.ABC):
   """Subclasses will be registered and given a `name` property."""
@@ -198,24 +129,23 @@ def __init_subclass__(cls, skip_registration=False, **kwargs):  # pylint: disabl
     _DATASET_COLLECTION_REGISTRY[cls.name] = cls


-def list_imported_dataset_collections() -> list[str]:
+def list_imported_dataset_collections() -> List[str]:
   """Returns the string names of all `tfds.core.DatasetCollection`s."""
-  _import_dataset_collections()
-  all_dataset_collections = list(_DATASET_COLLECTION_REGISTRY.keys())
+  all_dataset_collections = [
+      dataset_collection_name
+      for dataset_collection_name, dataset_collection_cls in _DATASET_COLLECTION_REGISTRY.items()
+  ]
   return sorted(all_dataset_collections)


 def is_dataset_collection(name: str) -> bool:
-  _import_dataset_collections()
   return name in _DATASET_COLLECTION_REGISTRY


 def imported_dataset_collection_cls(
     name: str,
 ) -> Type[RegisteredDatasetCollection]:
   """Returns the Registered dataset class."""
-  _import_dataset_collections()
-
   if name in _ABSTRACT_DATASET_COLLECTION_REGISTRY:
     raise AssertionError(f'DatasetCollection {name} is an abstract class.')

@@ -294,9 +224,8 @@ def _is_builder_available(builder_cls: Type[RegisteredDataset]) -> bool:
   return visibility.DatasetType.TFDS_PUBLIC.is_available()


-def list_imported_builders() -> list[str]:
+def list_imported_builders() -> List[str]:
   """Returns the string names of all `tfds.core.DatasetBuilder`s."""
-  _import_legacy_builders()
   all_builders = [
       builder_name
       for builder_name, builder_cls in _DATASET_REGISTRY.items()
@@ -307,8 +236,8 @@ def list_imported_builders() -> list[str]:

 @functools.lru_cache(maxsize=None)
 def _get_existing_dataset_packages(
-    datasets_dir: str,
-) -> dict[str, tuple[epath.Path, str]]:
+    datasets_dir: Text,
+) -> Dict[Text, Tuple[epath.Path, Text]]:
   """Returns existing datasets.

   Args:
@@ -364,12 +293,7 @@ def imported_builder_cls(name: str) -> Type[RegisteredDataset]:
     raise AssertionError(f'Dataset {name} is an abstract class.')

   if name not in _DATASET_REGISTRY:
-    # Dataset not found in the registry, try to import legacy builders.
-    # Dataset builders are imported lazily to avoid slowing down the startup
-    # of the binary.
-    _import_legacy_builders()
-    if name not in _DATASET_REGISTRY:
-      raise DatasetNotFoundError(f'Dataset {name} not found.')
+    raise DatasetNotFoundError(f'Dataset {name} not found.')

   builder_cls = _DATASET_REGISTRY[name]
   if not _is_builder_available(builder_cls):
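The functions changed above read from module-level registries that are filled as a side effect of subclassing, via the `__init_subclass__` hook visible in the `@@ -198,24 +129,23 @@` context. Below is a minimal, generic sketch of that registration pattern; the class names are illustrative rather than the actual TFDS classes, and the real code derives a snake_cased `name` instead of using `__name__.lower()`.

import abc

_REGISTRY = {}  # maps registered name -> class


class Registered(abc.ABC):
  """Base class; concrete subclasses register themselves at definition time."""

  def __init_subclass__(cls, skip_registration=False, **kwargs):
    super().__init_subclass__(**kwargs)
    if not skip_registration:
      # The real code computes a snake_cased `name`; simplified here.
      _REGISTRY[cls.__name__.lower()] = cls


class MyDataset(Registered):
  pass


def list_registered():
  return sorted(_REGISTRY)


print(list_registered())  # ['mydataset']

Passing `skip_registration=True` as a class keyword (e.g. `class Tmp(Registered, skip_registration=True)`) keeps a subclass out of the registry, which mirrors the `skip_registration` argument in the hunk above.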

tensorflow_datasets/rlds/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -13,10 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""RLDS utils.
-
-Note that the datasets are imported lazily so they don't need to be in here.
-"""
+"""Datasets generated with RLDS."""

+from tensorflow_datasets.rlds import datasets
 from tensorflow_datasets.rlds import envlogger_reader
 from tensorflow_datasets.rlds import rlds_base
