Skip to content

Commit c50f1e4

Browse files
authored
Add patch to ignore 'licenses', 'info', 'type'
1 parent c745e39 commit c50f1e4

File tree

1 file changed

+371
-0
lines changed

1 file changed

+371
-0
lines changed

pycocoevalcap/coco.py

+371
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
__author__ = 'tylin'
2+
__version__ = '1.0.1'
3+
# Interface for accessing the Microsoft COCO dataset.
4+
5+
# Microsoft COCO is a large image dataset designed for object detection,
6+
# segmentation, and caption generation. pycocotools is a Python API that
7+
# assists in loading, parsing and visualizing the annotations in COCO.
8+
# Please visit http://mscoco.org/ for more information on COCO, including
9+
# for the data, paper, and tutorials. The exact format of the annotations
10+
# is also described on the COCO website. For example usage of the pycocotools
11+
# please see pycocotools_demo.ipynb. In addition to this API, please download both
12+
# the COCO images and annotations in order to run the demo.
13+
14+
# An alternative to using the API is to load the annotations directly
15+
# into Python dictionary
16+
# Using the API provides additional utility functions. Note that this API
17+
# supports both *instance* and *caption* annotations. In the case of
18+
# captions not all functions are defined (e.g. categories are undefined).
19+
20+
# The following API functions are defined:
21+
# COCO - COCO api class that loads COCO annotation file and prepare data structures.
22+
# decodeMask - Decode binary mask M encoded via run-length encoding.
23+
# encodeMask - Encode binary mask M using run-length encoding.
24+
# getAnnIds - Get ann ids that satisfy given filter conditions.
25+
# getCatIds - Get cat ids that satisfy given filter conditions.
26+
# getImgIds - Get img ids that satisfy given filter conditions.
27+
# loadAnns - Load anns with the specified ids.
28+
# loadCats - Load cats with the specified ids.
29+
# loadImgs - Load imgs with the specified ids.
30+
# segToMask - Convert polygon segmentation to binary mask.
31+
# showAnns - Display the specified annotations.
32+
# loadRes - Load result file and create result api object.
33+
# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
34+
# Help on each functions can be accessed by: "help COCO>function".
35+
36+
# See also COCO>decodeMask,
37+
# COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
38+
# COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
39+
# COCO>loadImgs, COCO>segToMask, COCO>showAnns
40+
41+
# Microsoft COCO Toolbox. Version 1.0
42+
# Data, paper, and tutorials available at: http://mscoco.org/
43+
# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
44+
# Licensed under the Simplified BSD License [see bsd.txt]
45+
46+
import json
47+
import datetime
48+
import matplotlib.pyplot as plt
49+
from matplotlib.collections import PatchCollection
50+
from matplotlib.patches import Polygon
51+
import numpy as np
52+
from skimage.draw import polygon
53+
import copy
54+
import pdb
55+
56+
class COCO:
57+
def __init__(self, annotation_file=None):
58+
"""
59+
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
60+
:param annotation_file (str): location of annotation file
61+
:param image_folder (str): location to the folder that hosts images.
62+
:return:
63+
"""
64+
# load dataset
65+
self.dataset = {}
66+
self.anns = []
67+
self.imgToAnns = {}
68+
self.catToImgs = {}
69+
self.imgs = []
70+
self.cats = []
71+
if not annotation_file == None:
72+
print 'loading annotations into memory...'
73+
time_t = datetime.datetime.utcnow()
74+
dataset = json.load(open(annotation_file, 'r'))
75+
print datetime.datetime.utcnow() - time_t
76+
self.dataset = dataset
77+
self.createIndex()
78+
79+
def createIndex(self):
80+
# create index
81+
print 'creating index...'
82+
imgToAnns = {ann['image_id']: [] for ann in self.dataset['annotations']}
83+
anns = {ann['id']: [] for ann in self.dataset['annotations']}
84+
for ann in self.dataset['annotations']:
85+
imgToAnns[ann['image_id']] += [ann]
86+
anns[ann['id']] = ann
87+
88+
imgs = {im['id']: {} for im in self.dataset['images']}
89+
for img in self.dataset['images']:
90+
imgs[img['id']] = img
91+
92+
cats = []
93+
catToImgs = []
94+
95+
# if self.dataset['type'] == 'instances':
96+
# cats = {cat['id']: [] for cat in self.dataset['categories']}
97+
# for cat in self.dataset['categories']:
98+
# cats[cat['id']] = cat
99+
# catToImgs = {cat['id']: [] for cat in self.dataset['categories']}
100+
# for ann in self.dataset['annotations']:
101+
# catToImgs[ann['category_id']] += [ann['image_id']]
102+
103+
print 'index created!'
104+
105+
# create class members
106+
self.anns = anns
107+
self.imgToAnns = imgToAnns
108+
self.catToImgs = catToImgs
109+
self.imgs = imgs
110+
self.cats = cats
111+
112+
def info(self):
113+
"""
114+
Print information about the annotation file.
115+
:return:
116+
"""
117+
for key, value in self.datset['info'].items():
118+
print '%s: %s'%(key, value)
119+
120+
def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
121+
"""
122+
Get ann ids that satisfy given filter conditions. default skips that filter
123+
:param imgIds (int array) : get anns for given imgs
124+
catIds (int array) : get anns for given cats
125+
areaRng (float array) : get anns for given area range (e.g. [0 inf])
126+
iscrowd (boolean) : get anns for given crowd label (False or True)
127+
:return: ids (int array) : integer array of ann ids
128+
"""
129+
imgIds = imgIds if type(imgIds) == list else [imgIds]
130+
catIds = catIds if type(catIds) == list else [catIds]
131+
132+
if len(imgIds) == len(catIds) == len(areaRng) == 0:
133+
anns = self.dataset['annotations']
134+
else:
135+
if not len(imgIds) == 0:
136+
anns = sum([self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns],[])
137+
else:
138+
anns = self.dataset['annotations']
139+
anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
140+
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
141+
if self.dataset['type'] == 'instances':
142+
if not iscrowd == None:
143+
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
144+
else:
145+
ids = [ann['id'] for ann in anns]
146+
else:
147+
ids = [ann['id'] for ann in anns]
148+
return ids
149+
150+
def getCatIds(self, catNms=[], supNms=[], catIds=[]):
151+
"""
152+
filtering parameters. default skips that filter.
153+
:param catNms (str array) : get cats for given cat names
154+
:param supNms (str array) : get cats for given supercategory names
155+
:param catIds (int array) : get cats for given cat ids
156+
:return: ids (int array) : integer array of cat ids
157+
"""
158+
catNms = catNms if type(catNms) == list else [catNms]
159+
supNms = supNms if type(supNms) == list else [supNms]
160+
catIds = catIds if type(catIds) == list else [catIds]
161+
162+
if len(catNms) == len(supNms) == len(catIds) == 0:
163+
cats = self.dataset['categories']
164+
else:
165+
cats = self.dataset['categories']
166+
cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
167+
cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
168+
cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
169+
ids = [cat['id'] for cat in cats]
170+
return ids
171+
172+
def getImgIds(self, imgIds=[], catIds=[]):
173+
'''
174+
Get img ids that satisfy given filter conditions.
175+
:param imgIds (int array) : get imgs for given ids
176+
:param catIds (int array) : get imgs with all given cats
177+
:return: ids (int array) : integer array of img ids
178+
'''
179+
imgIds = imgIds if type(imgIds) == list else [imgIds]
180+
catIds = catIds if type(catIds) == list else [catIds]
181+
182+
if len(imgIds) == len(catIds) == 0:
183+
ids = self.imgs.keys()
184+
else:
185+
ids = set(imgIds)
186+
for catId in catIds:
187+
if len(ids) == 0:
188+
ids = set(self.catToImgs[catId])
189+
else:
190+
ids &= set(self.catToImgs[catId])
191+
return list(ids)
192+
193+
def loadAnns(self, ids=[]):
194+
"""
195+
Load anns with the specified ids.
196+
:param ids (int array) : integer ids specifying anns
197+
:return: anns (object array) : loaded ann objects
198+
"""
199+
if type(ids) == list:
200+
return [self.anns[id] for id in ids]
201+
elif type(ids) == int:
202+
return [self.anns[ids]]
203+
204+
def loadCats(self, ids=[]):
205+
"""
206+
Load cats with the specified ids.
207+
:param ids (int array) : integer ids specifying cats
208+
:return: cats (object array) : loaded cat objects
209+
"""
210+
if type(ids) == list:
211+
return [self.cats[id] for id in ids]
212+
elif type(ids) == int:
213+
return [self.cats[ids]]
214+
215+
def loadImgs(self, ids=[]):
216+
"""
217+
Load anns with the specified ids.
218+
:param ids (int array) : integer ids specifying img
219+
:return: imgs (object array) : loaded img objects
220+
"""
221+
if type(ids) == list:
222+
return [self.imgs[id] for id in ids]
223+
elif type(ids) == int:
224+
return [self.imgs[ids]]
225+
226+
def showAnns(self, anns):
227+
"""
228+
Display the specified annotations.
229+
:param anns (array of object): annotations to display
230+
:return: None
231+
"""
232+
if len(anns) == 0:
233+
return 0
234+
if self.dataset['type'] == 'instances':
235+
ax = plt.gca()
236+
polygons = []
237+
color = []
238+
for ann in anns:
239+
c = np.random.random((1, 3)).tolist()[0]
240+
if type(ann['segmentation']) == list:
241+
# polygon
242+
for seg in ann['segmentation']:
243+
poly = np.array(seg).reshape((len(seg)/2, 2))
244+
polygons.append(Polygon(poly, True,alpha=0.4))
245+
color.append(c)
246+
else:
247+
# mask
248+
mask = COCO.decodeMask(ann['segmentation'])
249+
img = np.ones( (mask.shape[0], mask.shape[1], 3) )
250+
if ann['iscrowd'] == 1:
251+
color_mask = np.array([2.0,166.0,101.0])/255
252+
if ann['iscrowd'] == 0:
253+
color_mask = np.random.random((1, 3)).tolist()[0]
254+
for i in range(3):
255+
img[:,:,i] = color_mask[i]
256+
ax.imshow(np.dstack( (img, mask*0.5) ))
257+
p = PatchCollection(polygons, facecolors=color, edgecolors=(0,0,0,1), linewidths=3, alpha=0.4)
258+
ax.add_collection(p)
259+
if self.dataset['type'] == 'captions':
260+
for ann in anns:
261+
print ann['caption']
262+
263+
def loadRes(self, resFile):
264+
"""
265+
Load result file and return a result api object.
266+
:param resFile (str) : file name of result file
267+
:return: res (obj) : result api object
268+
"""
269+
res = COCO()
270+
res.dataset['images'] = [img for img in self.dataset['images']]
271+
# res.dataset['info'] = copy.deepcopy(self.dataset['info'])
272+
# res.dataset['type'] = copy.deepcopy(self.dataset['type'])
273+
# res.dataset['licenses'] = copy.deepcopy(self.dataset['licenses'])
274+
275+
print 'Loading and preparing results... '
276+
time_t = datetime.datetime.utcnow()
277+
anns = json.load(open(resFile))
278+
assert type(anns) == list, 'results in not an array of objects'
279+
annsImgIds = [ann['image_id'] for ann in anns]
280+
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
281+
'Results do not correspond to current coco set'
282+
if 'caption' in anns[0]:
283+
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
284+
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
285+
for id, ann in enumerate(anns):
286+
ann['id'] = id
287+
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
288+
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
289+
for id, ann in enumerate(anns):
290+
bb = ann['bbox']
291+
x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
292+
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
293+
ann['area'] = bb[2]*bb[3]
294+
ann['id'] = id
295+
ann['iscrowd'] = 0
296+
elif 'segmentation' in anns[0]:
297+
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
298+
for id, ann in enumerate(anns):
299+
ann['area']=sum(ann['segmentation']['counts'][2:-1:2])
300+
ann['bbox'] = []
301+
ann['id'] = id
302+
ann['iscrowd'] = 0
303+
print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
304+
305+
res.dataset['annotations'] = anns
306+
res.createIndex()
307+
return res
308+
309+
310+
@staticmethod
311+
def decodeMask(R):
312+
"""
313+
Decode binary mask M encoded via run-length encoding.
314+
:param R (object RLE) : run-length encoding of binary mask
315+
:return: M (bool 2D array) : decoded binary mask
316+
"""
317+
N = len(R['counts'])
318+
M = np.zeros( (R['size'][0]*R['size'][1], ))
319+
n = 0
320+
val = 1
321+
for pos in range(N):
322+
val = not val
323+
for c in range(R['counts'][pos]):
324+
R['counts'][pos]
325+
M[n] = val
326+
n += 1
327+
return M.reshape((R['size']), order='F')
328+
329+
@staticmethod
330+
def encodeMask(M):
331+
"""
332+
Encode binary mask M using run-length encoding.
333+
:param M (bool 2D array) : binary mask to encode
334+
:return: R (object RLE) : run-length encoding of binary mask
335+
"""
336+
[h, w] = M.shape
337+
M = M.flatten(order='F')
338+
N = len(M)
339+
counts_list = []
340+
pos = 0
341+
# counts
342+
counts_list.append(1)
343+
diffs = np.logical_xor(M[0:N-1], M[1:N])
344+
for diff in diffs:
345+
if diff:
346+
pos +=1
347+
counts_list.append(1)
348+
else:
349+
counts_list[pos] += 1
350+
# if array starts from 1. start with 0 counts for 0
351+
if M[0] == 1:
352+
counts_list = [0] + counts_list
353+
return {'size': [h, w],
354+
'counts': counts_list ,
355+
}
356+
357+
@staticmethod
358+
def segToMask( S, h, w ):
359+
"""
360+
Convert polygon segmentation to binary mask.
361+
:param S (float array) : polygon segmentation mask
362+
:param h (int) : target mask height
363+
:param w (int) : target mask width
364+
:return: M (bool 2D array) : binary mask
365+
"""
366+
M = np.zeros((h,w), dtype=np.bool)
367+
for s in S:
368+
N = len(s)
369+
rr, cc = polygon(np.array(s[1:N:2]), np.array(s[0:N:2])) # (y, x)
370+
M[rr, cc] = 1
371+
return M

0 commit comments

Comments
 (0)