1
1
from pathlib import Path
2
2
from typing import Optional , List , Callable , Dict , Any , Union
3
+ import warnings
3
4
4
5
import PIL .Image as pil_image
5
6
from torch import Tensor
8
9
9
10
from taming .data .conditional_builder .objects_bbox import ObjectsBoundingBoxConditionalBuilder
10
11
from taming .data .conditional_builder .objects_center_points import ObjectsCenterPointsConditionalBuilder
12
+ from taming .data .conditional_builder .utils import load_object_from_string
11
13
from taming .data .helper_types import BoundingBox , CropMethodType , Image , Annotation , SplitType
12
14
from taming .data .image_transforms import CenterCropReturnCoordinates , RandomCrop1dReturnCoordinates , \
13
15
Random2dCropReturnCoordinates , RandomHorizontalFlipReturn , convert_pil_to_tensor
@@ -17,7 +19,7 @@ class AnnotatedObjectsDataset(Dataset):
17
19
def __init__ (self , data_path : Union [str , Path ], split : SplitType , keys : List [str ], target_image_size : int ,
18
20
min_object_area : float , min_objects_per_image : int , max_objects_per_image : int ,
19
21
crop_method : CropMethodType , random_flip : bool , no_tokens : int , use_group_parameter : bool ,
20
- encode_crop : bool ):
22
+ encode_crop : bool , category_allow_list_target : str , category_mapping_target : str ):
21
23
self .data_path = data_path
22
24
self .split = split
23
25
self .keys = keys
@@ -40,6 +42,12 @@ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str
40
42
self .transform_functions : List [Callable ] = self .setup_transform (target_image_size , crop_method , random_flip )
41
43
self .paths = self .build_paths (self .data_path )
42
44
self ._conditional_builders = None
45
+ if category_allow_list_target :
46
+ allow_list = load_object_from_string (category_allow_list_target )
47
+ self .category_allow_list = {name for name , _ in allow_list }
48
+ self .category_mapping = {}
49
+ if category_mapping_target :
50
+ self .category_mapping = load_object_from_string (category_mapping_target )
43
51
44
52
def build_paths (self , top_level : Union [str , Path ]) -> Dict [str , Path ]:
45
53
top_level = Path (top_level )
@@ -123,12 +131,22 @@ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
123
131
return self ._conditional_builders
124
132
125
133
def filter_categories(self) -> None:
    """Restrict ``self.categories`` using the allow list and the mapping.

    Two filters run in sequence, each only when its configuration is
    non-empty:

    1. keep only categories whose ``name`` is in ``category_allow_list``;
    2. drop categories whose ``id`` is contained in ``category_mapping``.

    ``self.categories`` is rebound to a freshly built dict on each step;
    the previous dict is not mutated in place.
    """
    # The constructor only assigns ``category_allow_list`` when an allow-list
    # target was configured, so the attribute may be missing entirely.
    allow_list = getattr(self, 'category_allow_list', None)
    if allow_list:
        self.categories = {
            id_: cat for id_, cat in self.categories.items()
            if cat.name in allow_list
        }
    if self.category_mapping:
        self.categories = {
            id_: cat for id_, cat in self.categories.items()
            if cat.id not in self.category_mapping
        }
127
138
128
139
def setup_category_id_and_number(self) -> None:
    """Build the ordered category id list and the id -> index lookup.

    Sets ``self.category_ids`` to the sorted keys of ``self.categories``
    and ``self.category_number`` to a dict mapping each id to its position
    in that list. The id ``'/m/01s55n'``, when present, is moved to the end
    of the list so that it receives the highest index.

    Emits a ``UserWarning`` when an allow list is configured, no category
    mapping is in effect, and the number of resulting categories differs
    from the number of allow-listed names.
    """
    self.category_ids = sorted(self.categories)
    # NOTE(review): the reason this particular id must sort last is not
    # visible here - confirm against the code that consumes the indices.
    if '/m/01s55n' in self.category_ids:
        self.category_ids.remove('/m/01s55n')
        self.category_ids.append('/m/01s55n')
    self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}

    # The constructor defaults ``category_mapping`` to an empty dict (never
    # ``None``) and only sets ``category_allow_list`` when one is configured,
    # so test for emptiness / absence rather than identity with ``None``.
    allow_list = getattr(self, 'category_allow_list', None)
    if allow_list and not self.category_mapping \
            and len(self.category_ids) != len(allow_list):
        warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
                      'Make sure all names in category_allow_list exist.')
132
150
133
151
def clean_up_annotations_and_image_descriptions (self ) -> None :
134
152
image_id_set = set (self .image_ids )
0 commit comments