# cihp_pascal_atr.py (forked from Gaoyiminggithub/Graphonomy)
from __future__ import print_function, division

import os
import random

import numpy as np
from PIL import Image, ImageFile
from torch.utils.data import Dataset

from .mypath_cihp import Path
from .mypath_pascal import Path as PP
from .mypath_atr import Path as PA

# Allow Pillow to load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
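
# This loader merges three human-parsing datasets into a single index space
# and tags every sample with its origin: 0 = CIHP, 1 = PASCAL, 2 = ATR.
# CIHP and ATR ship pre-flipped label maps ('Category_rev_ids' /
# 'SegmentationClassAug_rev'), presumably because left/right part labels have
# to swap when an image is mirrored; PASCAL has no such directory, so its
# labels are mirrored on the fly in _make_img_gt_point_pair().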


class VOCSegmentation(Dataset):
    """
    Combined CIHP + PASCAL + ATR human-parsing dataset.
    """

    def __init__(self,
                 cihp_dir=Path.db_root_dir('cihp'),
                 split='train',
                 transform=None,
                 flip=False,
                 pascal_dir=PP.db_root_dir('pascal'),
                 atr_dir=PA.db_root_dir('atr'),
                 ):
        """
        :param cihp_dir: path to the CIHP dataset directory
        :param pascal_dir: path to the PASCAL dataset directory
        :param atr_dir: path to the ATR dataset directory
        :param split: 'train'/'val'/'test', or a list of splits
        :param transform: transform to apply to each sample
        :param flip: if True, randomly mirror samples at load time
        """
        super(VOCSegmentation, self).__init__()
        # CIHP
        self._flip_flag = flip
        self._base_dir = cihp_dir
        self._image_dir = os.path.join(self._base_dir, 'Images')
        self._cat_dir = os.path.join(self._base_dir, 'Category_ids')
        self._flip_dir = os.path.join(self._base_dir, 'Category_rev_ids')
        # PASCAL (no pre-flipped label maps)
        self._base_dir_pascal = pascal_dir
        self._image_dir_pascal = os.path.join(self._base_dir_pascal, 'JPEGImages')
        self._cat_dir_pascal = os.path.join(self._base_dir_pascal, 'SegmentationPart')
        # ATR
        self._base_dir_atr = atr_dir
        self._image_dir_atr = os.path.join(self._base_dir_atr, 'JPEGImages')
        self._cat_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug')
        self._flip_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug_rev')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split
        self.transform = transform

        _splits_dir = os.path.join(self._base_dir, 'lists')
        _splits_dir_pascal = os.path.join(self._base_dir_pascal, 'list')
        _splits_dir_atr = os.path.join(self._base_dir_atr, 'list')

        self.im_ids = []
        self.images = []
        self.categories = []
        self.flip_categories = []
        self.datasets_lbl = []
        # per-dataset sample counts
        self.num_cihp = 0
        self.num_pascal = 0
        self.num_atr = 0

        # CIHP samples get dataset label 0
        for splt in self.split:
            with open(os.path.join(_splits_dir, splt + '_id.txt'), "r") as f:
                lines = f.read().splitlines()
            self.num_cihp += len(lines)
            for line in lines:
                _image = os.path.join(self._image_dir, line + '.jpg')
                _cat = os.path.join(self._cat_dir, line + '.png')
                _flip = os.path.join(self._flip_dir, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                assert os.path.isfile(_flip)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append(_flip)
                self.datasets_lbl.append(0)

        # PASCAL samples get dataset label 1 (its 'test' split reuses the 'val' ids)
        for splt in self.split:
            if splt == 'test':
                splt = 'val'
            with open(os.path.join(_splits_dir_pascal, splt + '_id.txt'), "r") as f:
                lines = f.read().splitlines()
            self.num_pascal += len(lines)
            for line in lines:
                _image = os.path.join(self._image_dir_pascal, line + '.jpg')
                _cat = os.path.join(self._cat_dir_pascal, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                # placeholder: PASCAL has no pre-flipped label map
                self.flip_categories.append([])
                self.datasets_lbl.append(1)

        # ATR samples get dataset label 2
        for splt in self.split:
            with open(os.path.join(_splits_dir_atr, splt + '_id.txt'), "r") as f:
                lines = f.read().splitlines()
            self.num_atr += len(lines)
            for line in lines:
                _image = os.path.join(self._image_dir_atr, line + '.jpg')
                _cat = os.path.join(self._cat_dir_atr, line + '.png')
                _flip = os.path.join(self._flip_dir_atr, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                assert os.path.isfile(_flip)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append(_flip)
                self.datasets_lbl.append(2)

        assert len(self.images) == len(self.categories)

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def get_class_num(self):
        return self.num_cihp, self.num_pascal, self.num_atr

    def __getitem__(self, index):
        _img, _target, _lbl = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}
        if self.transform is not None:
            sample = self.transform(sample)
        # dataset-of-origin tag: 0 = CIHP, 1 = PASCAL, 2 = ATR
        sample['pascal'] = _lbl
        return sample

    def _make_img_gt_point_pair(self, index):
        # Read the image (RGB) and its target label map
        _img = Image.open(self.images[index]).convert('RGB')
        type_lbl = self.datasets_lbl[index]
        if self._flip_flag and random.random() < 0.5:
            _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
            if type_lbl == 0 or type_lbl == 2:
                # CIHP/ATR: use the pre-flipped label map shipped with the dataset
                _target = Image.open(self.flip_categories[index])
            else:
                # PASCAL: mirror the label map on the fly
                _target = Image.open(self.categories[index])
                _target = _target.transpose(Image.FLIP_LEFT_RIGHT)
        else:
            _target = Image.open(self.categories[index])
        return _img, _target, type_lbl

    def __str__(self):
        return 'datasets(split=' + str(self.split) + ')'
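

# Because samples are appended dataset-by-dataset (CIHP, then PASCAL, then
# ATR), the counts returned by get_class_num() map directly onto contiguous
# index ranges. A minimal sketch (the helper below is illustrative, not part
# of the original file):
def dataset_index_ranges(dataset):
    n_cihp, n_pascal, n_atr = dataset.get_class_num()
    return {
        'cihp': range(0, n_cihp),
        'pascal': range(n_cihp, n_cihp + n_pascal),
        'atr': range(n_cihp + n_pascal, n_cihp + n_pascal + n_atr),
    }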


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    composed_transforms_tr = transforms.Compose([
        tr.RandomSized_new(512),
        tr.RandomRotate(15),
        tr.ToTensor_()])

    voc_train = VOCSegmentation(split='train',
                                transform=composed_transforms_tr)
    dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=1)

    for ii, sample in enumerate(dataloader):
        if ii > 10:
            break
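
    # A minimal visualization sketch for the batches above (an illustrative
    # addition, assuming ToTensor_ yields 'image' as a CxHxW tensor; the
    # 'pascal' entry is the dataset-of-origin tag set in __getitem__):
    sample = next(iter(dataloader))
    img = sample['image'][0].numpy().transpose(1, 2, 0)
    plt.figure()
    plt.imshow(img.astype('uint8') if img.max() > 1.0 else img)
    plt.title('dataset tag: {}'.format(int(sample['pascal'][0])))
    plt.show()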