|
| 1 | +class COCOAnnotationTransform(object): |
| 2 | + """Transforms a VOC annotation into a Tensor of bbox coords and label index |
| 3 | + Initilized with a dictionary lookup of classnames to indexes |
| 4 | +
|
| 5 | + Arguments: |
| 6 | + class_to_ind (dict, optional): dictionary lookup of classnames -> indexes |
| 7 | + (default: alphabetic indexing of VOC's 20 classes) |
| 8 | + keep_difficult (bool, optional): keep difficult instances or not |
| 9 | + (default: False) |
| 10 | + height (int): height |
| 11 | + width (int): width |
| 12 | + """ |
| 13 | + |
| 14 | +# def __init__(self) |
| 15 | + |
| 16 | + def __call__(self, target, width, height): |
| 17 | + """ |
| 18 | + Arguments: |
| 19 | + target (annotation) : the target annotation to be made usable |
| 20 | + will be an ET.Element |
| 21 | + Returns: |
| 22 | + a list containing lists of bounding boxes [bbox coords, class name] |
| 23 | + """ |
| 24 | + scale = np.array([width, height, width, height]) |
| 25 | + res = [] |
| 26 | + for obj in target: |
| 27 | + if 'bbox' in obj: |
| 28 | + bbox = obj['bbox'] |
| 29 | + bbox[2] += bbox[0] |
| 30 | + bbox[3] += bbox[1] |
| 31 | + label_idx = obj['category_id'] |
| 32 | + final_box = list(np.array(bbox)/scale) |
| 33 | + final_box.append(label_idx) |
| 34 | + res += [final_box] # [xmin, ymin, xmax, ymax, label_ind] |
| 35 | + return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] |
| 36 | + |
| 37 | + |
| 38 | +class COCODetection(data.Dataset): |
| 39 | + """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset. |
| 40 | + Args: |
| 41 | + root (string): Root directory where images are downloaded to. |
| 42 | + annFile (string): Path to json annotation file. |
| 43 | + transform (callable, optional): A function/transform that takes in an PIL image |
| 44 | + and returns a transformed version. E.g, ``transforms.ToTensor`` |
| 45 | + target_transform (callable, optional): A function/transform that takes in the |
| 46 | + target and transforms it. |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__(self, root, annFile, transform=None, target_transform=None): |
| 50 | + from pycocotools.coco import COCO |
| 51 | + self.root = root |
| 52 | + self.coco = COCO(annFile) |
| 53 | + self.ids = list(self.coco.imgs.keys()) |
| 54 | + self.transform = transform |
| 55 | + self.target_transform = target_transform |
| 56 | + |
| 57 | + def __getitem__(self, index): |
| 58 | + """ |
| 59 | + Args: |
| 60 | + index (int): Index |
| 61 | + Returns: |
| 62 | + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. |
| 63 | + """ |
| 64 | + im, gt, h, w = self.pull_item(index) |
| 65 | + return im, gt |
| 66 | + |
| 67 | + def __len__(self): |
| 68 | + return len(self.ids) |
| 69 | + |
| 70 | + def pull_item(self, index): |
| 71 | + """ |
| 72 | + Args: |
| 73 | + index (int): Index |
| 74 | + Returns: |
| 75 | + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. |
| 76 | + """ |
| 77 | + coco = self.coco |
| 78 | + img_id = self.ids[index] |
| 79 | + ann_ids = coco.getAnnIds(imgIds=img_id) |
| 80 | + target = coco.loadAnns(ann_ids) |
| 81 | + path = coco.loadImgs(img_id)[0]['file_name'] |
| 82 | + img = cv2.imread(os.path.join(self.root, path)) |
| 83 | + height, width, channels = img.shape |
| 84 | + if self.target_transform is not None: |
| 85 | + target = self.target_transform(target, width, height) |
| 86 | + if self.transform is not None: |
| 87 | + target = np.array(target) |
| 88 | + img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) |
| 89 | + # to rgb |
| 90 | + img = img[:, :, (2, 1, 0)] |
| 91 | + # img = img.transpose(2, 0, 1) |
| 92 | + target = np.hstack((boxes, np.expand_dims(labels, axis=1))) |
| 93 | + return torch.from_numpy(img).permute(2, 0, 1), target, height, width |
| 94 | + |
| 95 | + def __repr__(self): |
| 96 | + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' |
| 97 | + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) |
| 98 | + fmt_str += ' Root Location: {}\n'.format(self.root) |
| 99 | + tmp = ' Transforms (if any): ' |
| 100 | + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) |
| 101 | + tmp = ' Target Transforms (if any): ' |
| 102 | + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) |
| 103 | + return fmt_str |
0 commit comments