Skip to content

Commit 995b9ad

Browse files
committed
added python API (nonfunctional, need to updated paths, etc.)
1 parent 14f0137 commit 995b9ad

File tree

2 files changed

+456
-0
lines changed

2 files changed

+456
-0
lines changed

PythonAPI/coco.py

Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
__author__ = 'tylin'
2+
__version__ = 1.0
3+
# Interface for accessing the Microsoft COCO dataset.
4+
5+
# Microsoft COCO is a large image dataset designed for object detection,
6+
# segmentation, and caption generation. pycocotools is a Python API that
7+
# assists in loading, parsing and visualizing the annotations in COCO.
8+
# Please visit http://mscoco.org/ for more information on COCO, including
9+
# for the data, paper, and tutorials. The exact format of the annotations
10+
# is also described on the COCO website. For example usage of the pycocotools
11+
# please see pycocotools_demo.ipynb. In addition to this API, please download both
12+
# the COCO images and annotations in order to run the demo.
13+
14+
# An alternative to using the API is to load the annotations directly
15+
# into Python dictionary
16+
# Using the API provides additional utility functions. Note that this API
17+
# supports both *instance* and *caption* annotations. In the case of
18+
# captions not all functions are defined (e.g. categories are undefined).
19+
20+
# The following API functions are defined:
21+
# COCO - COCO api class that loads COCO annotation file and prepare data structures.
22+
# decodeMask - Decode binary mask M encoded via run-length encoding.
23+
# encodeMask - Encode binary mask M using run-length encoding.
24+
# getAnnIds - Get ann ids that satisfy given filter conditions.
25+
# getCatIds - Get cat ids that satisfy given filter conditions.
26+
# getImgIds - Get img ids that satisfy given filter conditions.
27+
# loadAnns - Load anns with the specified ids.
28+
# loadCats - Load cats with the specified ids.
29+
# loadImgs - Load imgs with the specified ids.
30+
# segToMask - Convert polygon segmentation to binary mask.
31+
# showAnns - Display the specified annotations.
32+
# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
33+
# Help on each functions can be accessed by: "help COCO>function".
34+
35+
# See also COCO>decodeMask,
36+
# COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
37+
# COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
38+
# COCO>loadImgs, COCO>segToMask, COCO>showAnns
39+
40+
# Microsoft COCO Toolbox. Version 1.0
41+
# Data, paper, and tutorials available at: http://mscoco.org/
42+
# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
43+
# Licensed under the Simplified BSD License [see bsd.txt]
44+
45+
import json
46+
import datetime
47+
import itertools
48+
import matplotlib.pyplot as plt
49+
import matplotlib.image as mpimg
50+
import pylab
51+
from matplotlib.collections import PatchCollection
52+
from matplotlib.patches import Polygon
53+
import numpy as np
54+
from skimage.draw import polygon
55+
import copy
56+
57+
class COCO:
58+
def __init__(self, annotation_file='annotations/instances_val2014_1_0.json'):
59+
"""
60+
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
61+
:param annotation_file (str): location of annotation file
62+
:param image_folder (str): location to the folder that hosts images.
63+
:return:
64+
"""
65+
# load dataset
66+
print 'loading annotations into memory...'
67+
time_t = datetime.datetime.utcnow()
68+
dataset = json.load(open(annotation_file, 'r'))
69+
print datetime.datetime.utcnow() - time_t
70+
print 'annotations loaded!'
71+
72+
time_t = datetime.datetime.utcnow()
73+
# create index
74+
print 'creating index...'
75+
imgToAnns = {ann['image_id']: [] for ann in dataset['annotations']}
76+
anns = {ann['id']: [] for ann in dataset['annotations']}
77+
for ann in dataset['annotations']:
78+
imgToAnns[ann['image_id']] += [ann]
79+
anns[ann['id']] = ann
80+
81+
imgs = {im['id']: {} for im in dataset['images']}
82+
for img in dataset['images']:
83+
imgs[img['id']] = img
84+
85+
cats = []
86+
catToImgs = []
87+
if dataset['type'] == 'instances':
88+
cats = {cat['id']: [] for cat in dataset['categories']}
89+
for cat in dataset['categories']:
90+
cats[cat['id']] = cat
91+
catToImgs = {cat['id']: [] for cat in dataset['categories']}
92+
for ann in dataset['annotations']:
93+
catToImgs[ann['category_id']] += [ann['image_id']]
94+
95+
print datetime.datetime.utcnow() - time_t
96+
print 'index created!'
97+
98+
# create class members
99+
self.anns = anns
100+
self.imgToAnns = imgToAnns
101+
self.catToImgs = catToImgs
102+
self.imgs = imgs
103+
self.cats = cats
104+
self.dataset = dataset
105+
106+
107+
def info(self):
108+
"""
109+
Print information about the annotation file.
110+
:return:
111+
"""
112+
for key, value in self.datset['info'].items():
113+
print '%s: %s'%(key, value)
114+
115+
def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
116+
"""
117+
Get ann ids that satisfy given filter conditions. default skips that filter
118+
:param imgIds (int array) : get anns for given imgs
119+
catIds (int array) : get anns for given cats
120+
areaRng (float array) : get anns for given area range (e.g. [0 inf])
121+
iscrowd (boolean) : get anns for given crowd label (False or True)
122+
123+
:return: ids (int array) : integer array of ann ids
124+
"""
125+
imgIds = imgIds if type(imgIds) == list else [imgIds]
126+
catIds = catIds if type(catIds) == list else [catIds]
127+
128+
if len(imgIds) == len(catIds) == len(areaRng) == 0:
129+
anns = self.dataset['annotations']
130+
else:
131+
if not len(imgIds) == 0:
132+
anns = sum([self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns],[])
133+
else:
134+
anns = self.dataset['annotations']
135+
anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
136+
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
137+
if self.dataset['type'] == 'instances':
138+
if not iscrowd == None:
139+
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
140+
else:
141+
ids = [ann['id'] for ann in anns]
142+
else:
143+
ids = [ann['id'] for ann in anns]
144+
return ids
145+
146+
def getCatIds(self, catNms=[], supNms=[], catIds=[]):
147+
"""
148+
filtering parameters. default skips that filter.
149+
:param catNms (str array) : get cats for given cat names
150+
:param supNms (str array) : get cats for given supercategory names
151+
:param catIds (int array) : get cats for given cat ids
152+
:return: ids (int array) : integer array of cat ids
153+
"""
154+
catNms = catNms if type(catNms) == list else [catNms]
155+
supNms = supNms if type(supNms) == list else [supNms]
156+
catIds = catIds if type(catIds) == list else [catIds]
157+
158+
if len(catNms) == len(supNms) == len(catIds) == 0:
159+
cats = self.dataset['categories']
160+
else:
161+
cats = self.dataset['categories']
162+
cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
163+
cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
164+
cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
165+
ids = [cat['id'] for cat in cats]
166+
return ids
167+
168+
def getImgIds(self, imgIds=[], catIds=[]):
169+
'''
170+
Get img ids that satisfy given filter conditions.
171+
:param imgIds (int array) : get imgs for given ids
172+
:param catIds (int array) : get imgs with all given cats
173+
:return: ids (int array) : integer array of img ids
174+
'''
175+
imgIds = imgIds if type(imgIds) == list else [imgIds]
176+
catIds = catIds if type(catIds) == list else [catIds]
177+
178+
if len(imgIds) == len(catIds) == 0:
179+
ids = self.imgs.keys()
180+
else:
181+
ids = set(imgIds)
182+
for catId in catIds:
183+
if len(ids) == 0:
184+
ids = set(self.catToImgs[catId])
185+
else:
186+
ids &= set(self.catToImgs[catId])
187+
return list(ids)
188+
189+
def loadAnns(self, ids=[]):
190+
"""
191+
Load anns with the specified ids.
192+
:param ids (int array) : integer ids specifying anns
193+
:return: anns (object array) : loaded ann objects
194+
"""
195+
if type(ids) == list:
196+
return [self.anns[id] for id in ids]
197+
elif type(ids) == int:
198+
return [self.anns[ids]]
199+
200+
def loadCats(self, ids=[]):
201+
"""
202+
Load cats with the specified ids.
203+
:param ids (int array) : integer ids specifying cats
204+
:return: cats (object array) : loaded cat objects
205+
"""
206+
if type(ids) == list:
207+
return [self.cats[id] for id in ids]
208+
elif type(ids) == int:
209+
return [self.cats[ids]]
210+
211+
def loadImgs(self, ids=[]):
212+
"""
213+
Load anns with the specified ids.
214+
:param ids (int array) : integer ids specifying img
215+
:return: imgs (object array) : loaded img objects
216+
"""
217+
if type(ids) == list:
218+
return [self.imgs[id] for id in ids]
219+
elif type(ids) == int:
220+
return [self.imgs[ids]]
221+
222+
def getImageIds(self, params={}):
223+
"""
224+
Get image IDs from annotations. One can use params to get filtered results.
225+
:param params (dict): { 'cat_id': [int]}
226+
Filter images that contain specified object category.
227+
If params is empty, return all image IDs in the dataset.
228+
:return: a list of image IDs
229+
"""
230+
# load all images if no constraint specified
231+
if params == {}:
232+
return self.imgs.keys()
233+
# get instances with filtering constraints
234+
im_id_lists = []
235+
# specific filtering for instances annotations
236+
if self.ann_key == 'instances' and 'cat_id' in params.keys():
237+
im_id_lists.append( [ann['image_id'] for ann_id, ann in self.annotations.items() if ann['category_id'] == params['cat_id']] )
238+
# aggregate the queries by AND operation
239+
if len(im_id_lists) == 0:
240+
im_id_list = []
241+
for i, l in enumerate(im_id_lists):
242+
assert isinstance(l, list)
243+
im_id_list = set(im_id_list) & set(l) if not i == 0 else set(l)
244+
return list(im_id_list)
245+
246+
def showAnns(self, anns):
247+
"""
248+
Display the specified annotations.
249+
:param anns (array of object): annotations to display
250+
:return: None
251+
"""
252+
if len(anns) == 0:
253+
return 0
254+
if self.dataset['type'] == 'instances':
255+
ax = plt.gca()
256+
polygons = []
257+
color = []
258+
for ann in anns:
259+
c = np.random.random((1, 3)).tolist()[0]
260+
if not ann['iscrowd']:
261+
# polygon
262+
for seg in ann['segmentation']:
263+
poly = np.array(seg).reshape((len(seg)/2, 2))
264+
polygons.append(Polygon(poly, True,alpha=0.4))
265+
color.append(c)
266+
else:
267+
# mask
268+
mask = COCO.decodeMask(ann['segmentation'])
269+
img = np.ones( (mask.shape[0], mask.shape[1], 3) )
270+
img[:,:,:] = 64
271+
# for i in range(3):
272+
# img[:,:,i] *= c[i]*255
273+
ax.imshow(np.dstack( (img, mask*0.5) ))
274+
p = PatchCollection(polygons, facecolors=color, edgecolors=(0,0,0,1), linewidths=3, alpha=0.4)
275+
ax.add_collection(p)
276+
if self.dataset['type'] == 'captions':
277+
for ann in anns:
278+
print ann['caption']
279+
280+
281+
@staticmethod
282+
def decodeMask(R):
283+
"""
284+
Decode binary mask M encoded via run-length encoding.
285+
:param R (object RLE) : run-length encoding of binary mask
286+
:return: M (bool 2D array) : decoded binary mask
287+
"""
288+
N = len(R['counts'])
289+
M = np.zeros( (R['size'][0]*R['size'][1], ))
290+
n = 0
291+
val = 1
292+
for pos in range(N):
293+
val = not val
294+
for c in range(R['counts'][pos]):
295+
R['counts'][pos]
296+
M[n] = val
297+
n += 1
298+
return M.reshape((R['size']), order='F')
299+
300+
@staticmethod
301+
def encodeMask(M):
302+
"""
303+
Encode binary mask M using run-length encoding.
304+
:param M (bool 2D array) : binary mask to encode
305+
:return: R (object RLE) : run-length encoding of binary mask
306+
"""
307+
[h, w] = M.shape
308+
M = M.flatten(order='F')
309+
N = len(M)
310+
counts_list = []
311+
pos = 0
312+
# counts
313+
counts_list.append(1)
314+
diffs = np.logical_xor(M[0:N-1], M[1:N])
315+
for diff in diffs:
316+
if diff:
317+
pos +=1
318+
counts_list.append(1)
319+
else:
320+
counts_list[pos] += 1
321+
# if array starts from 1. start with 0 counts for 0
322+
if M[0] == 1:
323+
counts_list = [0] + counts_list
324+
return {'size': [h, w],
325+
'counts': counts_list ,
326+
}
327+
328+
@staticmethod
329+
def segToMask( S, h, w ):
330+
"""
331+
Convert polygon segmentation to binary mask.
332+
:param S (float array) : polygon segmentation mask
333+
:param h (int) : target mask height
334+
:param w (int) : target mask width
335+
:return: M (bool 2D array) : binary mask
336+
"""
337+
M = np.zeros((h,w), dtype=np.bool)
338+
for s in S:
339+
N = len(s)
340+
rr, cc = polygon(np.array(s[1:N:2]), np.array(s[0:N:2])) # (y, x)
341+
M[rr, cc] = 1
342+
return M

0 commit comments

Comments
 (0)