forked from Gaoyiminggithub/Graphonomy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path pascal.py
106 lines (79 loc) · 3.2 KB
/
pascal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import print_function, division
import os
from PIL import Image
from torch.utils.data import Dataset
from .mypath_pascal import Path
class VOCSegmentation(Dataset):
    """
    PASCAL segmentation dataset.

    Ids are read from ``<base_dir>/list/<split>_id.txt`` (one id per line).
    For each id, the RGB image ``<base_dir>/JPEGImages/<id>.jpg`` is paired
    with the label map ``<base_dir>/SegmentationPart/<id>.png``.

    Each item is a dict ``{'image': <PIL image converted to RGB>,
    'label': <PIL image as stored on disk>}``, optionally passed through
    ``transform``.
    """

    def __init__(self,
                 base_dir=Path.db_root_dir('pascal'),
                 split='train',
                 transform=None,
                 ):
        """
        :param base_dir: path to PASCAL dataset directory
        :param split: a split name (e.g. 'train'/'val') or an iterable of
            split names; an iterable is copied and sorted, so the caller's
            object is never modified
        :param transform: optional callable applied to each sample dict
        :raises IOError: if an image or label file listed in a split is missing
        """
        # Pass `self`: the one-argument form super(Cls) is unbound and never
        # reaches Dataset.__init__.
        super(VOCSegmentation, self).__init__()
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
        self._cat_dir = os.path.join(self._base_dir, 'SegmentationPart')

        if isinstance(split, str):
            self.split = [split]
        else:
            # Sort a copy instead of calling split.sort(), which mutated the
            # caller's list and failed for tuples.
            self.split = sorted(split)

        self.transform = transform

        _splits_dir = os.path.join(self._base_dir, 'list')

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            with open(os.path.join(_splits_dir, splt + '_id.txt'), "r") as f:
                lines = f.read().splitlines()

            for line in lines:
                line = line.strip()
                if not line:
                    # Blank lines would otherwise resolve to '<dir>/.jpg'.
                    continue

                _image = os.path.join(self._image_dir, line + '.jpg')
                _cat = os.path.join(self._cat_dir, line + '.png')

                # Explicit checks (not assert) so validation survives python -O.
                if not os.path.isfile(_image):
                    raise IOError('Image file not found: {}'.format(_image))
                if not os.path.isfile(_cat):
                    raise IOError('Label file not found: {}'.format(_cat))

                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        """Return the number of (image, label) pairs found across all splits."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``{'image': ..., 'label': ...}`` for ``index``, transformed if set."""
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def _make_img_gt_point_pair(self, index):
        """Open the image (converted to RGB) and its label map for ``index``."""
        _img = Image.open(self.images[index]).convert('RGB')
        # The label is returned without conversion, in the file's own mode.
        _target = Image.open(self.categories[index])
        return _img, _target

    def __str__(self):
        return 'PASCAL(split=' + str(self.split) + ')'
class test_segmentation(VOCSegmentation):
    """
    VOCSegmentation variant that additionally records a ``flip`` flag.

    NOTE(review): ``flip`` is stored as ``self._flip_flag`` but nothing in
    this class reads it; presumably it is consumed by evaluation code
    elsewhere. Confirm before relying on it.
    """

    def __init__(self, base_dir=Path.db_root_dir('pascal'),
                 split='train',
                 transform=None,
                 flip=True):
        """
        :param base_dir: path to PASCAL dataset directory
        :param split: split name or iterable of split names
        :param transform: optional callable applied to each sample dict
        :param flip: stored in ``self._flip_flag``; not used by this class itself
        """
        super(test_segmentation, self).__init__(base_dir=base_dir, split=split, transform=transform)
        self._flip_flag = flip

    def __getitem__(self, index):
        """Return the sample at ``index``; same behavior as the parent class."""
        # Delegate rather than duplicating the parent's body, so fixes to
        # VOCSegmentation.__getitem__ apply here too.
        return super(test_segmentation, self).__getitem__(index)