
Commit 6e6303a

Small fixes for Open Images dataset 🔧

1 parent 70a17a5 commit 6e6303a

5 files changed: +48 -15 lines

configs/coco_scene_images_transformer.yaml (+2 -3)

@@ -48,13 +48,12 @@ data:
   target: main.DataModuleFromConfig
   params:
     batch_size: 6
-    num_workers: 12
     train:
       target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
       params:
         data_path: data/coco_annotations_100 # substitute with path to full dataset
         split: train
-        keys: [image, objects_bbox, file_name]
+        keys: [image, objects_bbox, file_name, annotations]
         no_tokens: 8192
         target_image_size: 256
         min_object_area: 0.00001
@@ -69,7 +68,7 @@ data:
       params:
         data_path: data/coco_annotations_100 # substitute with path to full dataset
         split: validation
-        keys: [image, objects_bbox, file_name]
+        keys: [image, objects_bbox, file_name, annotations]
         no_tokens: 8192
         target_image_size: 256
         min_object_area: 0.00001
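The `keys` list controls which fields each dataset sample exposes; adding `annotations` makes the raw `Annotation` tuples travel alongside the image and bounding-box conditioning. A toy stand-in for that effect (the real filtering happens inside `AnnotatedObjectsDataset.__getitem__`; the placeholder values and the extra `objects_center_points` field are illustrative assumptions, not the repo's code):

# Toy illustration of what the `keys` list selects from a full sample dict.
full_sample = {'image': 'image tensor', 'objects_bbox': 'bbox tokens', 'file_name': 'x.jpg',
               'annotations': 'raw Annotation tuples', 'objects_center_points': 'unused here'}
keys = ['image', 'objects_bbox', 'file_name', 'annotations']
sample = {key: full_sample[key] for key in keys}
assert set(sample) == set(keys)  # only the requested fields survive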

configs/open_images_scene_images_transformer.yaml (+9 -6)

@@ -8,9 +8,9 @@ model:
   params:
     vocab_size: 8192
     block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
-    n_layer: 40
+    n_layer: 36
     n_head: 16
-    n_embd: 1408
+    n_embd: 1536
     embd_pdrop: 0.1
     resid_pdrop: 0.1
     attn_pdrop: 0.1
@@ -48,15 +48,16 @@ data:
   target: main.DataModuleFromConfig
   params:
     batch_size: 6
-    num_workers: 12
     train:
       target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
       params:
         data_path: data/open_images_annotations_100 # substitute with path to full dataset
         split: train
-        keys: [image, objects_bbox, file_name]
+        keys: [image, objects_bbox, file_name, annotations]
         no_tokens: 8192
         target_image_size: 256
+        category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
+        category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
        min_object_area: 0.0001
        min_objects_per_image: 2
        max_objects_per_image: 30
@@ -65,13 +66,15 @@ data:
         use_group_parameter: true
         encode_crop: true
     validation:
-      target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
+      target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
       params:
         data_path: data/open_images_annotations_100 # substitute with path to full dataset
         split: validation
-        keys: [image, objects_bbox, file_name]
+        keys: [image, objects_bbox, file_name, annotations]
         no_tokens: 8192
         target_image_size: 256
+        category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
+        category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
         min_object_area: 0.0001
         min_objects_per_image: 2
         max_objects_per_image: 30
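Three things change in this config: the validation split now instantiates the Open Images dataset class instead of the COCO one, both splits gain dotted-path `category_allow_list_target`/`category_mapping_target` entries (resolved at runtime by the new `load_object_from_string` helper further down), and the transformer is reshaped from 40 layers at width 1408 to 36 layers at width 1536. The commit message gives no motive for the resize; as a back-of-the-envelope sanity check, the usual rough estimate of 12 · n_layer · n_embd² trunk parameters (ignoring embeddings and biases) puts both shapes in the same ~1B regime:

# Rough GPT-trunk parameter estimate: 12 * n_layer * n_embd**2
old_params = 12 * 40 * 1408 ** 2   # 951,582,720 ≈ 0.95e9
new_params = 12 * 36 * 1536 ** 2   # 1,019,215,872 ≈ 1.02e9
print(f'{old_params:,} -> {new_params:,}')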

taming/data/annotated_objects_dataset.py (+20 -2)

@@ -1,5 +1,6 @@
 from pathlib import Path
 from typing import Optional, List, Callable, Dict, Any, Union
+import warnings

 import PIL.Image as pil_image
 from torch import Tensor
@@ -8,6 +9,7 @@

 from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
 from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.conditional_builder.utils import load_object_from_string
 from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
 from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
     Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
@@ -17,7 +19,7 @@ class AnnotatedObjectsDataset(Dataset):
     def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
                  min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
                  crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
-                 encode_crop: bool):
+                 encode_crop: bool, category_allow_list_target: str, category_mapping_target: str):
         self.data_path = data_path
         self.split = split
         self.keys = keys
@@ -40,6 +42,12 @@ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str
         self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
         self.paths = self.build_paths(self.data_path)
         self._conditional_builders = None
+        if category_allow_list_target:
+            allow_list = load_object_from_string(category_allow_list_target)
+            self.category_allow_list = {name for name, _ in allow_list}
+        self.category_mapping = {}
+        if category_mapping_target:
+            self.category_mapping = load_object_from_string(category_mapping_target)

     def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
         top_level = Path(top_level)
@@ -123,12 +131,22 @@ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
         return self._conditional_builders

     def filter_categories(self) -> None:
-        pass
+        if self.category_allow_list:
+            self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
+        if self.category_mapping:
+            self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}

     def setup_category_id_and_number(self) -> None:
         self.category_ids = list(self.categories.keys())
         self.category_ids.sort()
+        if '/m/01s55n' in self.category_ids:
+            self.category_ids.remove('/m/01s55n')
+            self.category_ids.append('/m/01s55n')
         self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
+        if self.category_allow_list is not None and self.category_mapping is None \
+                and len(self.category_ids) != len(self.category_allow_list):
+            warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
+                          'Make sure all names in category_allow_list exist.')

     def clean_up_annotations_and_image_descriptions(self) -> None:
         image_id_set = set(self.image_ids)
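`filter_categories` now drops categories outside the allow list, and also drops categories whose id is a key of the mapping (those get folded into their mapped target instead of counted twice). The `'/m/01s55n'` special case moves that Open Images label id from its sorted position to the end of the id list, so it always receives the highest category number. A standalone illustration of that reordering (the two neighbor ids are made up; only `'/m/01s55n'` is from the commit):

# Sort, then pin '/m/01s55n' to the end so it gets the highest number.
category_ids = sorted(['/m/0c9ph5', '/m/01s55n', '/m/083vt'])
if '/m/01s55n' in category_ids:
    category_ids.remove('/m/01s55n')
    category_ids.append('/m/01s55n')
category_number = {category_id: i for i, category_id in enumerate(category_ids)}
assert category_number['/m/01s55n'] == len(category_ids) - 1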

taming/data/annotated_objects_open_images.py (+8 -4)

@@ -33,8 +33,8 @@
 }


-def load_annotations(descriptor_path: Path, min_object_area: float, category_no_for_id: Dict[str, int]) -> \
-        Dict[str, List[Annotation]]:
+def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
+                     category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
     annotations: Dict[str, List[Annotation]] = defaultdict(list)
     with open(descriptor_path) as file:
         reader = DictReader(file)
@@ -43,6 +43,8 @@ def load_annotations(descriptor_path: Path, min_object_area: float, category_no_
             height = float(row['YMax']) - float(row['YMin'])
             area = width * height
             category_id = row['LabelName']
+            if category_id in category_mapping:
+                category_id = category_mapping[category_id]
             if area >= min_object_area and category_id in category_no_for_id:
                 annotations[row['ImageID']].append(
                     Annotation(
@@ -114,7 +116,8 @@ def __init__(self, **kwargs):
         self.setup_category_id_and_number()

         self.image_descriptions = {}
-        annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_number)
+        annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
+                                       self.category_number)
         self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
                                                      self.max_objects_per_image)
         self.image_ids = list(self.annotations.keys())
@@ -129,4 +132,5 @@ def get_image_path(self, image_id: str) -> Path:
         return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')

     def get_image_description(self, image_id: str) -> Dict[str, Any]:
-        return {'file_path': str(self.get_image_path(image_id))}
+        image_path = self.get_image_path(image_id)
+        return {'file_path': str(image_path), 'file_name': image_path.name}
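The ordering in `load_annotations` matters: each raw `LabelName` is rewritten through `category_mapping` before the `category_no_for_id` membership test, so boxes carrying a unified-away id are kept under their canonical id rather than silently dropped. (`get_image_description` additionally gains a `file_name` entry, which backs the `file_name` key in the configs above.) A toy walk-through of the remapping order; both ids below are hypothetical, not real Open Images MIDs:

# Remap first, then test membership against the numbering of kept categories.
category_mapping = {'/m/raw_id': '/m/canonical_id'}   # hypothetical unification entry
category_no_for_id = {'/m/canonical_id': 0}           # numbering built after filter_categories
category_id = '/m/raw_id'
if category_id in category_mapping:
    category_id = category_mapping[category_id]       # now '/m/canonical_id'
assert category_id in category_no_for_id              # the annotation is kept, not dropped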

taming/data/conditional_builder/utils.py (+9)

@@ -1,3 +1,4 @@
+import importlib
 from typing import List, Any, Tuple, Optional

 from taming.data.helper_types import BoundingBox, Annotation
@@ -94,3 +95,11 @@ def get_circle_size(figure_size: Tuple[int, int]) -> int:
     if max(figure_size) >= 512:
         circle_size = 4
     return circle_size
+
+
+def load_object_from_string(object_string: str) -> Any:
+    """
+    Source: https://stackoverflow.com/a/10773699
+    """
+    module_name, class_name = object_string.rsplit(".", 1)
+    return getattr(importlib.import_module(module_name), class_name)
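`load_object_from_string` splits the dotted path at its last dot, imports the module half, and returns the named attribute, so despite the `class_name` variable it resolves any module-level object: classes, functions, or the plain lists and dicts used as category targets in the configs above. A small usage sketch, with `pathlib.Path` standing in for those targets:

# Resolve a 'module.attribute' string to the object it names.
from taming.data.conditional_builder.utils import load_object_from_string

path_cls = load_object_from_string('pathlib.Path')
assert path_cls.__name__ == 'Path'
# Same mechanism as e.g.:
# load_object_from_string('taming.data.open_images_helper.open_images_unify_categories_for_coco')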
