Skip to content

add prototype for HMDB51 dataset #4541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ ignore_missing_imports = True
[mypy-torchdata.*]

ignore_missing_imports = True

[mypy-rarfile.*]

ignore_missing_imports = True
8 changes: 4 additions & 4 deletions torchvision/prototype/datasets/_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import io
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import raw, pil
from torchvision.prototype.datasets.decoder import av, raw, pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
from torchvision.prototype.datasets.utils._internal import add_suggestion

Expand Down Expand Up @@ -50,16 +49,17 @@ def info(name: str) -> DatasetInfo:

default = object()

DEFAULT_DECODER: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
DEFAULT_DECODER: Dict[DatasetType, Callable[[io.IOBase], Dict[str, Any]]] = {
DatasetType.RAW: raw,
DatasetType.IMAGE: pil,
DatasetType.VIDEO: av,
}


def load(
name: str,
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = default, # type: ignore[assignment]
decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]] = default, # type: ignore[assignment]
split: str = "train",
**options: Any,
) -> IterDataPipe[Dict[str, Any]]:
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .hdmb51 import HMDB51
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .sbd import SBD
from .voc import VOC
19 changes: 10 additions & 9 deletions torchvision/prototype/datasets/_builtin/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
return category, id

def _collate_and_decode_sample(
self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
self, data, *, decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]]
) -> Dict[str, Any]:
key, image_data, ann_data = data
category, _ = key
Expand All @@ -93,28 +93,27 @@ def _collate_and_decode_sample(

label = self.info.categories.index(category)

image = decoder(image_buffer) if decoder else image_buffer

ann = read_mat(ann_buffer)
bbox = torch.as_tensor(ann["box_coord"].astype(np.int64))
contour = torch.as_tensor(ann["obj_contour"])

return dict(
sample = dict(
category=category,
label=label,
image=image,
image_path=image_path,
bbox=bbox,
contour=contour,
ann_path=ann_path,
)
sample.update(decoder(image_buffer) if decoder else dict(image=image_buffer))
return sample

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]],
) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps

Expand Down Expand Up @@ -169,22 +168,24 @@ def _collate_and_decode_sample(
self,
data: Tuple[str, io.IOBase],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]],
) -> Dict[str, Any]:
path, buffer = data

dir_name = pathlib.Path(path).parent.name
label_str, category = dir_name.split(".")
label = torch.tensor(int(label_str))

return dict(label=label, category=category, image=decoder(buffer) if decoder else buffer)
sample = dict(label=label, category=category)
sample.update(decoder(buffer) if decoder else dict(image=buffer))
return sample

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = TarArchiveReader(dp)
Expand Down
141 changes: 141 additions & 0 deletions torchvision/prototype/datasets/_builtin/hdmb51.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
Mapper,
Shuffler,
Filter,
)
from torchdata.datapipes.iter import KeyZipper, CSVParser
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
create_categories_file,
INFINITE_BUFFER_SIZE,
RarArchiveReader,
)

HERE = pathlib.Path(__file__).parent


class HMDB51(Dataset):
    """Prototype dataset for HMDB51, a 51-category video action-recognition dataset.

    NOTE(review): the module file is named ``hdmb51.py`` ("hdmb" vs "hmdb") --
    looks like a transposition typo; confirm before renaming, since
    ``_builtin/__init__.py`` imports it under that name.
    """

    @property
    def info(self) -> DatasetInfo:
        """Static dataset information: name, type, categories, and valid config options."""
        return DatasetInfo(
            "hmdb51",
            type="video",
            categories=HERE / "hmdb51.categories",
            homepage="https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/",
            valid_options=dict(
                split=("train", "test"),
                split_number=("1", "2", "3"),
            ),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the two archives needed: the train/test split files and the videos."""
        splits = HttpResource(
            "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
            sha256="229c94f845720d01eb3946d39f39292ea962d50a18136484aa47c1eba251d2b7",
        )
        videos = HttpResource(
            "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar",
            sha256="9e714a0d8b76104d76e932764a7ca636f929fff66279cda3f2e326fa912a328e",
        )
        return [splits, videos]

    # Split file names look like '<category>_test_split<1-3>.txt'.
    _SPLIT_FILE_PATTERN = re.compile(r"(?P<category>\w+?)_test_split(?P<split_number>[1-3])[.]txt")

    def _is_split_number(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
        """Keep only split files belonging to the configured split number."""
        path = pathlib.Path(data[0])
        split_number = self._SPLIT_FILE_PATTERN.match(path.name).group("split_number")  # type: ignore[union-attr]
        return split_number == config.split_number

    # Second column of a split-file line encodes membership; ids other than
    # "1"/"2" (e.g. "0") are filtered out entirely in _is_split.
    _SPLIT_ID_TO_NAME = {
        "1": "train",
        "2": "test",
    }

    def _is_split(self, data: List[str], *, config: DatasetConfig) -> bool:
        """Keep only split-file lines that belong to the configured split.

        FIX: the parameter was declared as ``config=DatasetConfig`` (making the
        *class object* the default value) instead of the intended annotation
        ``config: DatasetConfig``.
        """
        split_id = data[1]
        if split_id not in self._SPLIT_ID_TO_NAME:
            return False
        return self._SPLIT_ID_TO_NAME[split_id] == config.split

    def _splits_key(self, data: List[str]) -> str:
        # First column of a split-file line is the video file name.
        return data[0]

    def _videos_key(self, data: Tuple[str, Any]) -> str:
        path = pathlib.Path(data[0])
        return path.name

    def _collate_and_decode_sample(
        self, data: Tuple[List[str], Tuple[str, io.IOBase]], *, decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """Build a sample dict from a zipped (split line, (path, buffer)) pair.

        Without a decoder, the raw buffer is returned under the 'video' key.
        """
        _, video_data = data
        path, buffer = video_data

        # Videos live in one directory per category, so the parent directory
        # name is the category.
        category = pathlib.Path(path).parent.name
        label = torch.tensor(self.info.categories.index(category))

        sample = dict(
            path=path,
            category=category,
            label=label,
        )

        sample.update(decoder(buffer) if decoder else dict(video=buffer))
        return sample

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Wire the split and video archives into a single sample datapipe."""
        splits_dp, videos_dp = resource_dps

        splits_dp = RarArchiveReader(splits_dp)
        splits_dp = Filter(splits_dp, self._is_split_number, fn_kwargs=dict(config=config))
        splits_dp = CSVParser(splits_dp, delimiter=" ")
        splits_dp = Filter(splits_dp, self._is_split, fn_kwargs=dict(config=config))
        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

        # Applied twice on purpose: presumably the outer video archive contains
        # one nested .rar per category -- TODO confirm against the download.
        videos_dp = RarArchiveReader(videos_dp)
        videos_dp = RarArchiveReader(videos_dp)

        dp = KeyZipper(
            splits_dp,
            videos_dp,
            key_fn=self._splits_key,
            ref_key_fn=self._videos_key,
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
        """Regenerate 'hmdb51.categories' from the split-archive file names."""
        splits_archive = self.resources(self.default_config)[0]
        dp = splits_archive.to_datapipe(pathlib.Path(root) / self.name)
        dp = RarArchiveReader(dp)

        categories = {
            self._SPLIT_FILE_PATTERN.match(pathlib.Path(path).name).group("category")  # type: ignore[union-attr]
            for path, _ in dp
        }
        create_categories_file(HERE, self.name, sorted(categories))


if __name__ == "__main__":
    # Regenerate the checked-in category list next to this module.
    from torchvision.prototype.datasets import home

    HMDB51().generate_categories_file(home())
51 changes: 51 additions & 0 deletions torchvision/prototype/datasets/_builtin/hmdb51.categories
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
brush_hair
cartwheel
catch
chew
clap
climb
climb_stairs
dive
draw_sword
dribble
drink
eat
fall_floor
fencing
flic_flac
golf
handstand
hit
hug
jump
kick
kick_ball
kiss
laugh
pick
pour
pullup
punch
push
pushup
ride_bike
ride_horse
run
shake_hands
shoot_ball
shoot_bow
shoot_gun
sit
situp
smile
smoke
somersault
stand
swing_baseball
sword
sword_exercise
talk
throw
turn
walk
wave
14 changes: 8 additions & 6 deletions torchvision/prototype/datasets/_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,26 @@ def _collate_and_decode_data(
*,
root: pathlib.Path,
categories: List[str],
decoder,
decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]],
) -> Dict[str, Any]:
path, buffer = data
data = decoder(buffer) if decoder else buffer

category = pathlib.Path(path).relative_to(root).parts[0]
label = torch.tensor(categories.index(category))
return dict(

sample = dict(
path=path,
data=data,
label=label,
category=category,
)
sample.update(decoder(buffer) if decoder else dict(data=buffer))
return sample


def from_data_folder(
root: Union[str, pathlib.Path],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]] = None,
valid_extensions: Optional[Collection[str]] = None,
recursive: bool = True,
) -> Tuple[IterDataPipe, List[str]]:
Expand All @@ -67,7 +69,7 @@ def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
def from_image_folder(
root: Union[str, pathlib.Path],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
decoder: Optional[Callable[[io.IOBase], Dict[str, Any]]] = pil,
valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
**kwargs: Any,
) -> Tuple[IterDataPipe, List[str]]:
Expand Down
22 changes: 17 additions & 5 deletions torchvision/prototype/datasets/decoder.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
import io
import unittest.mock
from typing import Dict, Any

import PIL.Image
import torch
from torchvision.io.video import read_video
from torchvision.transforms.functional import pil_to_tensor

__all__ = ["raw", "pil"]
__all__ = ["raw", "pil", "av"]


def raw(buffer: io.IOBase) -> Dict[str, Any]:
    """Sentinel decoder marking that samples should stay undecoded.

    It is never meant to be executed, so invoking it raises immediately.
    """
    raise RuntimeError("This is just a sentinel and should never be called.")


def pil(buffer: io.IOBase, *, mode: str = "RGB") -> Dict[str, Any]:
    """Decode an encoded image with Pillow into ``{"image": tensor}``.

    Args:
        buffer: binary buffer holding the encoded image.
        mode: Pillow color mode to convert to; upper-cased before use.
    """
    image = PIL.Image.open(buffer).convert(mode.upper())
    return {"image": pil_to_tensor(image)}


def av(buffer: io.IOBase, **read_video_kwargs: Any) -> Dict[str, Any]:
    """Decode a video buffer through ``torchvision.io.video.read_video``.

    ``os.path.exists`` is patched out inside ``torchvision.io.video``,
    presumably because ``read_video`` validates its argument as a file path
    while we hand it a file-like object -- TODO confirm.

    Returns a dict keyed 'video', 'audio', and 'video_meta'.
    """
    keys = ("video", "audio", "video_meta")
    with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
        return dict(zip(keys, read_video(buffer, **read_video_kwargs)))  # type: ignore[arg-type]
Loading