__all__ = ["wrap_dataset_for_transforms_v2"]


-# TODO: naming!
def wrap_dataset_for_transforms_v2(dataset):
+    """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.
+
+    .. v2betastatus:: wrap_dataset_for_transforms_v2 function
+
+    Example:
+        >>> dataset = torchvision.datasets.CocoDetection(...)
+        >>> dataset = wrap_dataset_for_transforms_v2(dataset)
+
+    .. note::
+
+        For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
+        configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
+        to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.
+
+    The dataset samples are wrapped according to the description below.
+
+    Special cases:
+
+    * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as a list of dicts, the wrapper
+      returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
+      ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
+      The original keys are preserved.
+    * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
+      the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
+      preserved.
+    * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
+      coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
+    * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as a list of dicts, the wrapper returns
+      a dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
+      in the corresponding ``torchvision.datapoints``. The original keys are preserved.
+    * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
+      :class:`~torchvision.datapoints.Mask` datapoint.
+    * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
+      :class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by
+      a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
+      ``"labels"``.
+    * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
+      coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
+
+    Image classification datasets
+
+    This wrapper is a no-op for image classification datasets, since they were already fully supported by
+    :mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.
+
+    Segmentation datasets
+
+    Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
+    :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
+    segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).
+
+    Video classification datasets
+
+    Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
+    :class:`torch.Tensor` for the video and audio and an :class:`int` as label. This wrapper wraps the video into a
+    :class:`~torchvision.datapoints.Video` while leaving the other items as is.
+
+    .. note::
+
+        Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
+        ``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.
+
+    Args:
+        dataset: the dataset instance to wrap for compatibility with transforms v2.
+    """
    return VisionDatasetDatapointWrapper(dataset)

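For readers trying the change out, here is a minimal usage sketch (not part of this diff). It assumes a torchvision build that ships the beta ``torchvision.transforms.v2`` and ``torchvision.datapoints`` APIs described above; the COCO paths are placeholders.

    import torchvision
    from torchvision import datapoints
    from torchvision.datasets import wrap_dataset_for_transforms_v2
    from torchvision.transforms import v2

    # Placeholder paths; point these at a real COCO download.
    dataset = torchvision.datasets.CocoDetection(
        root="path/to/coco/images",
        annFile="path/to/coco/annotations.json",
    )
    dataset = wrap_dataset_for_transforms_v2(dataset)

    img, target = dataset[0]
    # As documented above, the target is now a dict of lists with extra
    # "boxes", "masks" and "labels" entries wrapped as datapoints.
    assert isinstance(target["boxes"], datapoints.BoundingBox)

    # v2 transforms dispatch on the datapoint types, so the boxes and masks
    # are updated together with the image.
    transforms = v2.Compose([v2.RandomHorizontalFlip(p=1.0)])
    img, target = transforms(img, target)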
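Along the same lines, a sketch of the segmentation case described in the docstring, again with a placeholder dataset root: the image is left untouched and only the mask is wrapped.

    import torchvision
    from torchvision import datapoints
    from torchvision.datasets import wrap_dataset_for_transforms_v2

    # Placeholder root; VOCSegmentation expects the extracted VOCdevkit layout here.
    dataset = torchvision.datasets.VOCSegmentation(
        root="path/to/voc", year="2012", image_set="train"
    )
    dataset = wrap_dataset_for_transforms_v2(dataset)

    img, mask = dataset[0]
    # The image stays a PIL.Image.Image; the mask is now a datapoints.Mask, so
    # geometric v2 transforms resample it with mask-appropriate (nearest) interpolation.
    assert isinstance(mask, datapoints.Mask)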