Skip to content

Add omniglot with new API #5459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,33 @@ def cub200(info, root, config):
return num_samples_map[config.split]


@register_mock
def omniglot(info, root, config):
    """Create a fake on-disk Omniglot archive for the requested split.

    Layout mirrors the real dataset: ``<split>/<alphabet>/<character>/*.jpg``,
    zipped into ``<split>.zip``. Returns the number of mock images per
    character folder for the configured split.
    """
    # Number of fake images per (alphabet, character) folder.
    num_images = {"images_background": 5, "images_evaluation": 5}

    split = config.split
    # Top-level split folder, then one empty folder per alphabet.
    create_image_folder(root, split, None, 0, 0)
    alphabets = ["Angelic", "Atemayar_Qelisayer"]
    characters = ["character01", "character02", "character03"]

    for alphabet in alphabets:
        create_image_folder(root / split, alphabet, None, 0, 0)

    # Populate each alphabet/character folder. The loop counters are bound as
    # default arguments so the file-name callback does not late-bind them
    # (a late-binding lambda would see only the final counter values if the
    # callback were invoked after the loop advances).
    for char_idx, character in enumerate(characters):
        for alphabet_idx, alphabet in enumerate(alphabets):
            create_image_folder(
                root / split / alphabet,
                character,
                lambda idx, i=char_idx, j=alphabet_idx: f"{idx + i}_{j}.jpg",
                num_images[split],
            )

    make_zip(root, f"{split}.zip")

    # Report the count for the *requested* split rather than hard-coding
    # "images_background" (both entries are currently equal, so behavior is
    # unchanged, but this stays correct if the mock sizes ever diverge).
    return num_images[split]


@register_mock
def svhn(info, root, config):
import scipy.io as sio
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIITPet
from .pcam import PCAM
from .sbd import SBD
Expand Down
85 changes: 85 additions & 0 deletions torchvision/prototype/datasets/_builtin/omniglot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import collections
import io
import pathlib
from typing import Any, Dict, List, Tuple

import numpy as np
from PIL import Image
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Label


class Omniglot(Dataset):
    """Omniglot handwritten-character dataset for the prototype datasets API.

    Samples are grayscale character images labeled by the alphabet they
    belong to. Homepage: https://github.com/brendenlake/omniglot
    """

    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "omniglot",
            type=DatasetType.IMAGE,
            homepage="https://github.com/brendenlake/omniglot",
            valid_options=dict(split=("images_background", "images_evaluation")),
        )

    # sha256 checksums of the per-split zip archives served from GitHub.
    _CHECKSUMS = {
        "images_background": "ad41ab679c8b5d90b271ef46be6987cc81211774a695c29dcc5367b2b26ee640",
        "images_evaluation": "1f61a8f3366785b057fc117d9228e78a16e3d976c8953b2a10fcc74cf0609cee",
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the single downloadable zip archive for the configured split."""
        data = HttpResource(
            f"https://raw.githubusercontent.com/brendenlake/omniglot/master/python/{config.split}.zip",
            sha256=self._CHECKSUMS[config.split],
        )
        return [data]

    def _get_alphabets_and_characters(self, dp: IterDataPipe) -> Tuple[List[str], List[List[str]]]:
        """Collect the alphabet -> characters mapping by scanning the archive once.

        Entry paths look like ``<split>/<alphabet>/<character>/<image>``, so the
        character and alphabet are the parent and grandparent directory names of
        each image path. The results are cached on ``self`` for the decoding
        steps and ``_generate_categories``.
        """
        categories: "collections.OrderedDict[str, List[str]]" = collections.OrderedDict()
        for path, _ in dp:
            image_path = pathlib.Path(path)
            # .name on the parent/grandparent replaces the original
            # as_posix().split("/")[-1] round-trip with the same result.
            character = image_path.parent.name
            alphabet = image_path.parents[1].name
            categories.setdefault(alphabet, []).append(character)
        self._alphabets = list(categories.keys())
        self._characters = [categories[alphabet] for alphabet in self._alphabets]
        return self._alphabets, self._characters

    def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> Tuple[Image.Image, int]:
        """Decode one archive entry into a grayscale PIL image and its alphabet index."""
        image_path, image_file = data
        alphabet_class = pathlib.Path(image_path).parents[1].name
        image = Image.open(image_file, mode="r").convert("L")
        idx = self._alphabets.index(alphabet_class)
        return image, idx

    def _collate_and_decode(self, data: Tuple[Image.Image, int]) -> Dict[str, Any]:
        """Wrap a decoded (image, index) pair into the sample dict with a typed Label.

        The input annotation was ``Tuple[np.ndarray, int]`` but the upstream
        mapper (``_read_images_and_labels``) produces a PIL image, not an array.
        """
        image, image_label = data

        label = Label(image_label, category=self._alphabets[image_label])
        return dict(image=image, label=label)

    def _make_datapipe(
        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
    ) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        # NOTE(review): this performs a full pass over the archive up front to
        # build the alphabet list, and assumes the resource datapipe can be
        # re-iterated afterwards — confirm against the resource loader.
        self._get_alphabets_and_characters(dp)

        dp = Mapper(dp, self._read_images_and_labels)
        dp = hint_sharding(dp)
        dp = hint_shuffling(dp)

        return Mapper(dp, self._collate_and_decode)

    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        # Relies on _make_datapipe (via _get_alphabets_and_characters) having
        # run first to populate self._alphabets.
        return self._alphabets