
Commit a7af2a1

tomvdw authored and The TensorFlow Datasets Authors committed
Lazily import legacy dataset builders
PiperOrigin-RevId: 638201042
1 parent 6bbba45 commit a7af2a1

File tree

4 files changed (+97, -72 lines)

tensorflow_datasets/__init__.py

Lines changed: 3 additions & 56 deletions
@@ -30,62 +30,17 @@

 * These API docs
 * [Available datasets](https://www.tensorflow.org/datasets/catalog/overview)
-* [Colab tutorial](https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb)
+* [Colab
+tutorial](https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb)
 * [Add a dataset](https://www.tensorflow.org/datasets/add_dataset)
 """
 # pylint: enable=line-too-long
 # pylint: disable=g-import-not-at-top,g-bad-import-order,wrong-import-position,unused-import

-import time
-
-_TIMESTAMP_IMPORT_STARTS = time.time()
 from absl import logging
-import tensorflow_datasets.core.logging as _tfds_logging
-from tensorflow_datasets.core.logging import call_metadata as _call_metadata
-
-_metadata = _call_metadata.CallMetadata()
-_metadata.start_time_micros = int(_TIMESTAMP_IMPORT_STARTS * 1e6)
-_import_time_ms_dataset_builders = 0

 try:
-  # Imports for registration
-  _before_dataset_imports = time.time()
-  from tensorflow_datasets import dataset_collections
-
-  # pytype: disable=import-error
-  # For builds that don't include all dataset builders, we don't want to fail on
-  # import errors of dataset builders.
-  try:
-    from tensorflow_datasets import audio
-    from tensorflow_datasets import graphs
-    from tensorflow_datasets import image
-    from tensorflow_datasets import image_classification
-    from tensorflow_datasets import object_detection
-    from tensorflow_datasets import nearest_neighbors
-    from tensorflow_datasets import question_answering
-    from tensorflow_datasets import d4rl
-    from tensorflow_datasets import ranking
-    from tensorflow_datasets import recommendation
-    from tensorflow_datasets import rl_unplugged
-    from tensorflow_datasets import rlds
-    from tensorflow_datasets import robotics
-    from tensorflow_datasets import robomimic
-    from tensorflow_datasets import structured
-    from tensorflow_datasets import summarization
-    from tensorflow_datasets import text
-    from tensorflow_datasets import text_simplification
-    from tensorflow_datasets import time_series
-    from tensorflow_datasets import translate
-    from tensorflow_datasets import video
-    from tensorflow_datasets import vision_language
-
-  except ImportError:
-    pass
-  # pytype: enable=import-error
-
-  _import_time_ms_dataset_builders = int(
-      (time.time() - _before_dataset_imports) * 1000
-  )
+  from tensorflow_datasets import rlds  # pylint: disable=g-bad-import-order

   # Public API to create and generate a dataset
   from tensorflow_datasets.public_api import *  # pylint: disable=wildcard-import
@@ -94,12 +49,4 @@
   __all__ = public_api.__all__

 except Exception as exception:  # pylint: disable=broad-except
-  _metadata.mark_error()
   logging.exception(exception)
-finally:
-  _metadata.mark_end()
-  _tfds_logging.tfds_import(
-      metadata=_metadata,
-      import_time_ms_tensorflow=0,
-      import_time_ms_dataset_builders=_import_time_ms_dataset_builders,
-  )
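
The upshot of this file's change is that `import tensorflow_datasets` no longer eagerly imports every legacy builder module. A minimal sketch of the user-visible behaviour (illustrative only; which call first pays the import cost depends on which registry lookup misses first):

```python
import tensorflow_datasets as tfds  # no longer imports audio, image, text, ... eagerly

# The legacy builder modules are pulled in on demand, e.g. the first time the
# registry is enumerated or a builder name has to be resolved.
names = tfds.list_builders()  # may trigger the deferred legacy-builder imports
```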

tensorflow_datasets/core/lazy_builder_import.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,6 @@
 defined under the "datasets/" directory and the legacy locations have been
 deprecated (ie: a new major release).
 """
-from typing import Text

 from absl import logging
 from tensorflow_datasets.core import registered
@@ -28,11 +27,12 @@
 class LazyBuilderImport:
   """Lazy load DatasetBuilder from given name from legacy locations."""

-  def __init__(self, dataset_name: Text):
+  def __init__(self, dataset_name: str):
     object.__setattr__(self, "_dataset_name", dataset_name)
     object.__setattr__(self, "_dataset_cls", None)

   def _get_builder_cls(self):
+    """Returns the DatasetBuilder class."""
     cls = object.__getattribute__(self, "_dataset_cls")
     if not cls:
       builder_name = object.__getattribute__(self, "_dataset_name")
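
For context, `LazyBuilderImport` is a proxy that defers resolving the real `DatasetBuilder` class until it is first used; the `object.__setattr__`/`object.__getattribute__` calls in the diff are what let the proxy hold its own state while intercepting attribute access. A minimal sketch of the same pattern, not the TFDS class itself (`_LazyModuleAttr` and the `decimal` example are illustrative):

```python
import importlib


class _LazyModuleAttr:
  """Resolves `module_name.attr_name` only on first attribute access."""

  def __init__(self, module_name: str, attr_name: str):
    # Bypass the normal attribute machinery so the proxy stays inert until used.
    object.__setattr__(self, '_module_name', module_name)
    object.__setattr__(self, '_attr_name', attr_name)
    object.__setattr__(self, '_target', None)

  def _resolve(self):
    target = object.__getattribute__(self, '_target')
    if target is None:
      module = importlib.import_module(object.__getattribute__(self, '_module_name'))
      target = getattr(module, object.__getattribute__(self, '_attr_name'))
      object.__setattr__(self, '_target', target)  # cache for later accesses
    return target

  def __getattr__(self, attr):
    # Only reached for attributes the proxy itself does not define.
    return getattr(self._resolve(), attr)


lazy_decimal = _LazyModuleAttr('decimal', 'Decimal')
print(lazy_decimal.from_float(0.5))  # the `decimal` module is imported only here
```

The real class goes further (its use of `object.__getattribute__` inside `_get_builder_cls` suggests attribute access itself is intercepted), but the deferral idea is the same.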

tensorflow_datasets/core/registered.py

Lines changed: 88 additions & 12 deletions
@@ -17,17 +17,22 @@

 import abc
 import collections
+from collections.abc import Iterator
 import contextlib
 import functools
 import importlib
 import inspect
 import os.path
-from typing import ClassVar, Dict, Iterator, List, Type, Text, Tuple
+import time
+from typing import ClassVar, Type

+from absl import logging
 from etils import epath
 from tensorflow_datasets.core import constants
 from tensorflow_datasets.core import naming
 from tensorflow_datasets.core import visibility
+import tensorflow_datasets.core.logging as _tfds_logging
+from tensorflow_datasets.core.logging import call_metadata as _call_metadata
 from tensorflow_datasets.core.utils import py_utils
 from tensorflow_datasets.core.utils import resource_utils

@@ -38,7 +43,7 @@
 # <str snake_cased_name, abstract DatasetBuilder subclass>
 _ABSTRACT_DATASET_REGISTRY = {}

-# Keep track of Dict[str (module name), List[DatasetBuilder]]
+# Keep track of dict[str (module name), list[DatasetBuilder]]
 # This is directly accessed by `tfds.community.builder_cls_from_module` when
 # importing community packages.
 _MODULE_TO_DATASETS = collections.defaultdict(list)
@@ -51,7 +56,7 @@
 # <str snake_cased_name, abstract DatasetCollectionBuilder subclass>
 _ABSTRACT_DATASET_COLLECTION_REGISTRY = {}

-# Keep track of Dict[str (module name), List[DatasetCollectionBuilder]]
+# Keep track of dict[str (module name), list[DatasetCollectionBuilder]]
 _MODULE_TO_DATASET_COLLECTIONS = collections.defaultdict(list)

 # eg for dataset "foo": "tensorflow_datasets.datasets.foo.foo_dataset_builder".
@@ -80,6 +85,70 @@ def skip_registration() -> Iterator[None]:
     _skip_registration = False


+@functools.cache
+def _import_legacy_builders() -> None:
+  """Imports legacy builders."""
+  modules_to_import = [
+      'audio',
+      'graphs',
+      'image',
+      'image_classification',
+      'object_detection',
+      'nearest_neighbors',
+      'question_answering',
+      'd4rl',
+      'ranking',
+      'recommendation',
+      'rl_unplugged',
+      'rlds.datasets',
+      'robotics',
+      'robomimic',
+      'structured',
+      'summarization',
+      'text',
+      'text_simplification',
+      'time_series',
+      'translate',
+      'video',
+      'vision_language',
+  ]
+
+  before_dataset_imports = time.time()
+  metadata = _call_metadata.CallMetadata()
+  metadata.start_time_micros = int(before_dataset_imports * 1e6)
+  try:
+    # For builds that don't include all dataset builders, we don't want to fail
+    # on import errors of dataset builders.
+    try:
+      for module in modules_to_import:
+        importlib.import_module(f'tensorflow_datasets.{module}')
+    except (ImportError, ModuleNotFoundError):
+      pass
+
+  except Exception as exception:  # pylint: disable=broad-except
+    metadata.mark_error()
+    logging.exception(exception)
+  finally:
+    import_time_ms_dataset_builders = int(
+        (time.time() - before_dataset_imports) * 1000
+    )
+    metadata.mark_end()
+    _tfds_logging.tfds_import(
+        metadata=metadata,
+        import_time_ms_tensorflow=0,
+        import_time_ms_dataset_builders=import_time_ms_dataset_builders,
+    )
+
+
+@functools.cache
+def _import_dataset_collections() -> None:
+  """Imports dataset collections."""
+  try:
+    importlib.import_module('tensorflow_datasets.dataset_collections')
+  except (ImportError, ModuleNotFoundError):
+    pass
+
+
 # The implementation of this class follows closely RegisteredDataset.
 class RegisteredDatasetCollection(abc.ABC):
   """Subclasses will be registered and given a `name` property."""
@@ -129,23 +198,24 @@ def __init_subclass__(cls, skip_registration=False, **kwargs):  # pylint: disabl
     _DATASET_COLLECTION_REGISTRY[cls.name] = cls


-def list_imported_dataset_collections() -> List[str]:
+def list_imported_dataset_collections() -> list[str]:
   """Returns the string names of all `tfds.core.DatasetCollection`s."""
-  all_dataset_collections = [
-      dataset_collection_name
-      for dataset_collection_name, dataset_collection_cls in _DATASET_COLLECTION_REGISTRY.items()
-  ]
+  _import_dataset_collections()
+  all_dataset_collections = list(_DATASET_COLLECTION_REGISTRY.keys())
   return sorted(all_dataset_collections)


 def is_dataset_collection(name: str) -> bool:
+  _import_dataset_collections()
   return name in _DATASET_COLLECTION_REGISTRY


 def imported_dataset_collection_cls(
     name: str,
 ) -> Type[RegisteredDatasetCollection]:
   """Returns the Registered dataset class."""
+  _import_dataset_collections()
+
   if name in _ABSTRACT_DATASET_COLLECTION_REGISTRY:
     raise AssertionError(f'DatasetCollection {name} is an abstract class.')

@@ -224,8 +294,9 @@ def _is_builder_available(builder_cls: Type[RegisteredDataset]) -> bool:
   return visibility.DatasetType.TFDS_PUBLIC.is_available()


-def list_imported_builders() -> List[str]:
+def list_imported_builders() -> list[str]:
   """Returns the string names of all `tfds.core.DatasetBuilder`s."""
+  _import_legacy_builders()
   all_builders = [
       builder_name
       for builder_name, builder_cls in _DATASET_REGISTRY.items()
@@ -236,8 +307,8 @@ def list_imported_builders() -> List[str]:

 @functools.lru_cache(maxsize=None)
 def _get_existing_dataset_packages(
-    datasets_dir: Text,
-) -> Dict[Text, Tuple[epath.Path, Text]]:
+    datasets_dir: str,
+) -> dict[str, tuple[epath.Path, str]]:
   """Returns existing datasets.

   Args:
@@ -293,7 +364,12 @@ def imported_builder_cls(name: str) -> Type[RegisteredDataset]:
     raise AssertionError(f'Dataset {name} is an abstract class.')

   if name not in _DATASET_REGISTRY:
-    raise DatasetNotFoundError(f'Dataset {name} not found.')
+    # Dataset not found in the registry, try to import legacy builders.
+    # Dataset builders are imported lazily to avoid slowing down the startup
+    # of the binary.
+    _import_legacy_builders()
+    if name not in _DATASET_REGISTRY:
+      raise DatasetNotFoundError(f'Dataset {name} not found.')

   builder_cls = _DATASET_REGISTRY[name]
   if not _is_builder_available(builder_cls):
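
The core mechanism added here is a `functools.cache`-decorated, zero-argument importer: the first call does the slow module imports (which register builders as a side effect), every later call is a memoized no-op, and registry lookups invoke it only when a name is missing. A minimal sketch of that pattern under assumed names (`_REGISTRY`, `_import_plugins`, `get_builder`, and the `my_pkg.*` modules are hypothetical, not TFDS APIs):

```python
import functools
import importlib

# Hypothetical registry, populated as a side effect of importing plugin modules.
_REGISTRY: dict[str, type] = {}


@functools.cache
def _import_plugins() -> None:
  """Runs its body at most once; functools.cache memoizes the first call."""
  for module in ('my_pkg.audio', 'my_pkg.text'):  # hypothetical module names
    try:
      importlib.import_module(module)  # registration happens on import
    except ImportError:
      pass  # optional modules may be absent in slim builds


def get_builder(name: str) -> type:
  if name not in _REGISTRY:
    _import_plugins()  # pay the import cost only when a lookup actually misses
  if name not in _REGISTRY:
    raise KeyError(f'Builder {name} not found.')
  return _REGISTRY[name]
```

This mirrors the fallback in `imported_builder_cls` above: startup stays fast because nothing is imported until a lookup cannot be satisfied from the already-populated registry.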

tensorflow_datasets/rlds/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Datasets generated with RLDS."""
+"""RLDS utils.
+
+Note that the datasets are imported lazily so they don't need to be in here.
+"""

-from tensorflow_datasets.rlds import datasets
 from tensorflow_datasets.rlds import envlogger_reader
 from tensorflow_datasets.rlds import rlds_base

0 commit comments