Skip to content

Commit 0628cb0

Browse files
committed
Add COCODetection starter code
1 parent 2196f7b commit 0628cb0

File tree

4 files changed

+105
-1
lines changed

4 files changed

+105
-1
lines changed

data/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .voc0712 import VOCDetection, AnnotationTransform, detection_collate, VOC_CLASSES
2+
from .coco import COCODetection, COCOAnnotationTransform
23
from .config import *
34
import cv2
45
import numpy as np

data/coco.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
class COCOAnnotationTransform(object):
2+
"""Transforms a VOC annotation into a Tensor of bbox coords and label index
3+
Initilized with a dictionary lookup of classnames to indexes
4+
5+
Arguments:
6+
class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
7+
(default: alphabetic indexing of VOC's 20 classes)
8+
keep_difficult (bool, optional): keep difficult instances or not
9+
(default: False)
10+
height (int): height
11+
width (int): width
12+
"""
13+
14+
# def __init__(self)
15+
16+
def __call__(self, target, width, height):
17+
"""
18+
Arguments:
19+
target (annotation) : the target annotation to be made usable
20+
will be an ET.Element
21+
Returns:
22+
a list containing lists of bounding boxes [bbox coords, class name]
23+
"""
24+
scale = np.array([width, height, width, height])
25+
res = []
26+
for obj in target:
27+
if 'bbox' in obj:
28+
bbox = obj['bbox']
29+
bbox[2] += bbox[0]
30+
bbox[3] += bbox[1]
31+
label_idx = obj['category_id']
32+
final_box = list(np.array(bbox)/scale)
33+
final_box.append(label_idx)
34+
res += [final_box] # [xmin, ymin, xmax, ymax, label_ind]
35+
return res # [[xmin, ymin, xmax, ymax, label_ind], ... ]
36+
37+
38+
class COCODetection(data.Dataset):
39+
"""`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
40+
Args:
41+
root (string): Root directory where images are downloaded to.
42+
annFile (string): Path to json annotation file.
43+
transform (callable, optional): A function/transform that takes in an PIL image
44+
and returns a transformed version. E.g, ``transforms.ToTensor``
45+
target_transform (callable, optional): A function/transform that takes in the
46+
target and transforms it.
47+
"""
48+
49+
def __init__(self, root, annFile, transform=None, target_transform=None):
50+
from pycocotools.coco import COCO
51+
self.root = root
52+
self.coco = COCO(annFile)
53+
self.ids = list(self.coco.imgs.keys())
54+
self.transform = transform
55+
self.target_transform = target_transform
56+
57+
def __getitem__(self, index):
58+
"""
59+
Args:
60+
index (int): Index
61+
Returns:
62+
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
63+
"""
64+
im, gt, h, w = self.pull_item(index)
65+
return im, gt
66+
67+
def __len__(self):
68+
return len(self.ids)
69+
70+
def pull_item(self, index):
71+
"""
72+
Args:
73+
index (int): Index
74+
Returns:
75+
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
76+
"""
77+
coco = self.coco
78+
img_id = self.ids[index]
79+
ann_ids = coco.getAnnIds(imgIds=img_id)
80+
target = coco.loadAnns(ann_ids)
81+
path = coco.loadImgs(img_id)[0]['file_name']
82+
img = cv2.imread(os.path.join(self.root, path))
83+
height, width, channels = img.shape
84+
if self.target_transform is not None:
85+
target = self.target_transform(target, width, height)
86+
if self.transform is not None:
87+
target = np.array(target)
88+
img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
89+
# to rgb
90+
img = img[:, :, (2, 1, 0)]
91+
# img = img.transpose(2, 0, 1)
92+
target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
93+
return torch.from_numpy(img).permute(2, 0, 1), target, height, width
94+
95+
def __repr__(self):
96+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
97+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
98+
fmt_str += ' Root Location: {}\n'.format(self.root)
99+
tmp = ' Transforms (if any): '
100+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
101+
tmp = ' Target Transforms (if any): '
102+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
103+
return fmt_str

data/voc0712.py

-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def __init__(self, root, image_sets, transform=None, target_transform=None,
115115

116116
def __getitem__(self, index):
117117
im, gt, h, w = self.pull_item(index)
118-
119118
return im, gt
120119

121120
def __len__(self):

utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .augmentations import SSDAugmentation

0 commit comments

Comments
 (0)