diff --git a/dataload/augmentations.py b/dataload/augmentations.py
new file mode 100644
index 0000000..8eecd60
--- /dev/null
+++ b/dataload/augmentations.py
@@ -0,0 +1,146 @@
+# coding=utf-8
+import cv2
+import random
+import numpy as np
+
+
+class HSV(object):
+    def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5, p=0.5):
+        self.hgain = hgain
+        self.sgain = sgain
+        self.vgain = vgain
+        self.p = p
+
+    def __call__(self, img, bboxes):
+        if random.random() < self.p:
+            x = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1  # random gains
+            img_hsv = (cv2.cvtColor(img, cv2.COLOR_BGR2HSV) * x).clip(None, 255).astype(np.uint8)
+            np.clip(img_hsv[:, :, 0], None, 179, out=img_hsv[:, :, 0])  # in-place hue clip (OpenCV hue range is 0-179)
+            img = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
+        return img, bboxes
+
+
+class RandomVerticalFlip(object):
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img, bboxes):
+        if random.random() < self.p:
+            h_img, _, _ = img.shape
+            img = img[::-1, :, :]  # reverse the rows (vertical flip)
+            bboxes[:, [1, 3]] = h_img - bboxes[:, [3, 1]]  # boxes are [xmin, ymin, xmax, ymax, class]
+        return img, bboxes
+
+
+class RandomHorizontalFlip(object):
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img, bboxes):
+        if random.random() < self.p:
+            _, w_img, _ = img.shape
+            img = img[:, ::-1, :]  # reverse the columns (horizontal flip)
+            bboxes[:, [0, 2]] = w_img - bboxes[:, [2, 0]]
+        return img, bboxes
+
+
+class RandomCrop(object):
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img, bboxes):
+        if random.random() < self.p:
+            h_img, w_img, _ = img.shape
+
+            # Tightest box enclosing all GT boxes; crop margins are sampled so no GT box is cut off.
+            max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0), np.max(bboxes[:, 2:4], axis=0)], axis=-1)
+            max_l_trans = max_bbox[0]
+            max_u_trans = max_bbox[1]
+            max_r_trans = w_img - max_bbox[2]
+            max_d_trans = h_img - max_bbox[3]
+
+            crop_xmin = max(0, int(max_bbox[0] - random.uniform(0, max_l_trans)))
+            crop_ymin = max(0, int(max_bbox[1] - random.uniform(0, max_u_trans)))
+            crop_xmax = min(w_img, int(max_bbox[2] + random.uniform(0, max_r_trans)))
+            crop_ymax = min(h_img, int(max_bbox[3] + random.uniform(0, max_d_trans)))
+
+            img = img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
+
+            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] - crop_xmin
+            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] - crop_ymin
+        return img, bboxes
+
+
+class RandomAffine(object):
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img, bboxes):
+        if random.random() < self.p:
+            h_img, w_img, _ = img.shape
+            # Translation range is limited so that every GT box stays inside the image.
+            max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0), np.max(bboxes[:, 2:4], axis=0)], axis=-1)
+            max_l_trans = max_bbox[0]
+            max_u_trans = max_bbox[1]
+            max_r_trans = w_img - max_bbox[2]
+            max_d_trans = h_img - max_bbox[3]
+
+            tx = random.uniform(-(max_l_trans - 1), (max_r_trans - 1))
+            ty = random.uniform(-(max_u_trans - 1), (max_d_trans - 1))
+
+            M = np.array([[1, 0, tx], [0, 1, ty]])
+            img = cv2.warpAffine(img, M, (w_img, h_img))
+
+            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] + tx
+            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] + ty
+        return img, bboxes
+
+
+class Resize(object):
+    def __init__(self, target_shape, correct_box=True):
+        self.h_target, self.w_target = target_shape
+        self.correct_box = correct_box
+
+    def __call__(self, img, bboxes):
+        h_org, w_org, _ = img.shape
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
+
+        # Letterbox resize: keep the aspect ratio, then pad to the target shape.
+        resize_ratio = min(1.0 * self.w_target / w_org, 1.0 * self.h_target / h_org)
+        resize_w = int(resize_ratio * w_org)
+        resize_h = int(resize_ratio * h_org)
+        image_resized = cv2.resize(img, (resize_w, resize_h))
+
+        image_padded = np.full((self.h_target, self.w_target, 3), 128.0)
+        dw = int((self.w_target - resize_w) / 2)
+        dh = int((self.h_target - resize_h) / 2)
+        image_padded[dh:resize_h + dh, dw:resize_w + dw, :] = image_resized
+        image = image_padded / 255.0
+
+        if self.correct_box:
+            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * resize_ratio + dw
+            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * resize_ratio + dh
+            return image, bboxes
+        return image
+
+
+class Mixup(object):
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img_org, bboxes_org, img_mix, bboxes_mix):
+        if random.random() < self.p:  # apply mixup with probability p, consistent with the other augmentations
+            lam = np.random.beta(1.5, 1.5)
+            img = lam * img_org + (1 - lam) * img_mix
+            # Append the mix weight as a sixth column: [xmin, ymin, xmax, ymax, class, weight]
+            bboxes_org = np.concatenate(
+                [bboxes_org, np.full((len(bboxes_org), 1), lam)], axis=1)
+            bboxes_mix = np.concatenate(
+                [bboxes_mix, np.full((len(bboxes_mix), 1), 1 - lam)], axis=1)
+            bboxes = np.concatenate([bboxes_org, bboxes_mix])
+        else:
+            img = img_org
+            bboxes = np.concatenate([bboxes_org, np.full((len(bboxes_org), 1), 1.0)], axis=1)
+        return img, bboxes
+
+
+class LabelSmooth(object):
+    def __init__(self, delta=0.01):
+        self.delta = delta
+
+    def __call__(self, onehot, num_classes):
+        return onehot * (1 - self.delta) + self.delta / num_classes
\ No newline at end of file
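Aside (not part of the diff): a minimal sketch of how Mixup above changes the label layout. The image shapes and box values are hypothetical; the point is the extra weight column.

import numpy as np
from dataload.augmentations import Mixup

img_a = np.random.rand(3, 448, 448)                # CHW float images, as produced in datasets.py below
img_b = np.random.rand(3, 448, 448)
boxes_a = np.array([[10., 10., 50., 50., 0.]])     # [xmin, ymin, xmax, ymax, class]
boxes_b = np.array([[100., 100., 200., 200., 2.]])
img, boxes = Mixup(p=0.5)(img_a, boxes_a, img_b, boxes_b)
print(boxes.shape[1])  # 6: the mix weight (lam or 1 - lam) is appended as the last column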
diff --git a/dataload/cocodataset.py b/dataload/cocodataset.py
new file mode 100644
index 0000000..1528537
--- /dev/null
+++ b/dataload/cocodataset.py
@@ -0,0 +1,116 @@
+import os
+from torch.utils.data import Dataset
+from pycocotools.coco import COCO
+
+import config.cfg_npmmrdet_dior as cfg
+from utils.utils_coco import *  # expected to provide cv2, np, torch, preprocess, random_distort, label2box
+
+
+class COCODataset(Dataset):
+    """
+    COCO dataset class.
+    """
+    def __init__(self, data_dir='COCO', json_file='instances_train2017.json',
+                 name='train2017', img_size=416,
+                 augmentation=None, min_size=1, debug=False):
+        """
+        COCO dataset initialization. Annotation data are read into memory by the COCO API.
+        Args:
+            data_dir (str): dataset root directory
+            json_file (str): COCO json file name
+            name (str): COCO data name (e.g. 'train2017' or 'val2017')
+            img_size (int): target image size after pre-processing
+            augmentation (dict): augmentation switches and parameters (see the keys read below)
+            min_size (int): bounding boxes smaller than this are ignored
+            debug (bool): if True, only one data id is selected from the dataset
+        """
+        self.data_dir = data_dir
+        self.json_file = json_file
+        self.coco = COCO(self.data_dir + '/json_gt/' + self.json_file)
+        self.ids = self.coco.getImgIds()
+        if debug:
+            self.ids = self.ids[1:2]
+            print("debug mode...", self.ids)
+        self.class_ids = sorted(self.coco.getCatIds())
+        self.name = name
+        self.max_labels = cfg.MAX_LABEL
+        self.img_size = img_size
+        self.min_size = min_size
+        self.lrflip = augmentation['LRFLIP']
+        self.jitter = augmentation['JITTER']
+        self.random_placing = augmentation['RANDOM_PLACING']
+        self.hue = augmentation['HUE']
+        self.saturation = augmentation['SATURATION']
+        self.exposure = augmentation['EXPOSURE']
+        self.random_distort = augmentation['RANDOM_DISTORT']
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __getitem__(self, index):
+        """
+        One image / label pair for the given index is picked up and pre-processed.
+        Args:
+            index (int): data index
+        Returns:
+            img (numpy.ndarray): pre-processed image
+            padded_labels (torch.Tensor): pre-processed label data.
+                The shape is :math:`[self.max_labels, 5]`.
+                Each label consists of [class, xc, yc, w, h]:
+                    class (float): class index.
+                    xc, yc (float): center of bbox, values range from 0 to 1.
+                    w, h (float): size of bbox, values range from 0 to 1.
+            info_img : tuple of h, w, nh, nw, dx, dy.
+                h, w (int): original shape of the image
+                nh, nw (int): shape of the resized image without padding
+                dx, dy (int): pad size
+            id_ (int): COCO image id of the sample. Used for evaluation.
+        """
+        id_ = self.ids[index]
+        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=None)
+        annotations = self.coco.loadAnns(anno_ids)
+        lrflip = False
+        if self.lrflip and np.random.rand() > 0.5:
+            lrflip = True
+
+        # load image and preprocess
+        img_file = os.path.join(self.data_dir, 'JPEGImages',
+                                '{:0>5d}'.format(id_) + '.jpg')
+        img = cv2.imread(img_file)
+        if self.json_file == 'instances_val5k.json' and img is None:
+            img_file = os.path.join(self.data_dir, 'train2017',
+                                    '{:012}'.format(id_) + '.jpg')
+            img = cv2.imread(img_file)
+        assert img is not None, 'Image not found: ' + img_file
+
+        img, info_img = preprocess(img, self.img_size, jitter=self.jitter,
+                                   random_placing=self.random_placing)
+
+        if self.random_distort:
+            img = random_distort(img, self.hue, self.saturation, self.exposure)
+
+        img = np.transpose(img / 255., (2, 0, 1))  # HWC -> CHW, scale to [0, 1]
+
+        if lrflip:
+            img = np.flip(img, axis=2).copy()
+
+        # load labels: [class_index, x, y, w, h] per annotation
+        labels = []
+        for anno in annotations:
+            if anno['bbox'][2] > self.min_size and anno['bbox'][3] > self.min_size:
+                labels.append([])
+                labels[-1].append(self.class_ids.index(anno['category_id']))
+                labels[-1].extend(anno['bbox'])
+
+        padded_labels = np.zeros((self.max_labels, 5))
+        if len(labels) > 0:
+            labels = np.stack(labels)
+            labels = label2box(labels, info_img, self.img_size, lrflip)
+            padded_labels[range(len(labels))[:self.max_labels]] = labels[:self.max_labels]
+        padded_labels = torch.from_numpy(padded_labels)
+
+        return img, padded_labels, info_img, id_, img_file
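Aside (not part of the diff): a hedged usage sketch for COCODataset. The augmentation keys mirror those read in __init__; the concrete values, directory layout, and MAX_LABEL come from the project config and are assumptions here.

from dataload.cocodataset import COCODataset

aug = {'LRFLIP': True, 'JITTER': 0.3, 'RANDOM_PLACING': True,
       'HUE': 0.1, 'SATURATION': 1.5, 'EXPOSURE': 1.5, 'RANDOM_DISTORT': False}
dataset = COCODataset(data_dir='COCO', json_file='instances_train2017.json',
                      img_size=416, augmentation=aug)
img, padded_labels, info_img, id_, img_file = dataset[0]
# img: (3, 416, 416) CHW array in [0, 1]; padded_labels: (MAX_LABEL, 5) tensor of
# [class, xc, yc, w, h]; info_img: (h, w, nh, nw, dx, dy) for mapping boxes back.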
diff --git a/dataload/datasets.py b/dataload/datasets.py
new file mode 100644
index 0000000..389b16b
--- /dev/null
+++ b/dataload/datasets.py
@@ -0,0 +1,180 @@
+# coding=utf-8
+
+import os
+import sys
+sys.path.append("..")
+sys.path.append("../utils")
+import numpy as np
+import cv2
+import random
+
+import torch
+from torch.utils.data import Dataset
+
+import config.cfg_npmmrdet_dior as cfg
+import dataload.augmentations as DataAug
+import utils.utils_basic as tools
+
+
+class Construct_Dataset(Dataset):
+    def __init__(self, anno_file_type, img_size=448):
+        self.img_size = img_size  # for multi-scale training
+        self.classes = cfg.DATA["CLASSES"]
+        self.num_classes = len(self.classes)
+        self.class_to_id = dict(zip(self.classes, range(self.num_classes)))
+        self.__annotations = self.__load_annotations(anno_file_type)
+
+    def __len__(self):
+        return len(self.__annotations)
+
+    def __getitem__(self, item):
+        img_org, bboxes_org = self.__parse_annotation(self.__annotations[item])
+        img_org = img_org.transpose(2, 0, 1)  # HWC -> CHW
+
+        item_mix = random.randint(0, len(self.__annotations) - 1)
+        img_mix, bboxes_mix = self.__parse_annotation(self.__annotations[item_mix])
+        img_mix = img_mix.transpose(2, 0, 1)
+
+        # bboxes are xyxy; Mixup appends the per-box mix weight as a sixth column
+        img, bboxes = DataAug.Mixup()(img_org, bboxes_org, img_mix, bboxes_mix)
+        del img_org, bboxes_org, img_mix, bboxes_mix
+
+        label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes = self.__create_label(bboxes)
+
+        img = torch.from_numpy(img).float()
+        label_sbbox = torch.from_numpy(label_sbbox).float()
+        label_mbbox = torch.from_numpy(label_mbbox).float()
+        label_lbbox = torch.from_numpy(label_lbbox).float()
+        sbboxes = torch.from_numpy(sbboxes).float()
+        mbboxes = torch.from_numpy(mbboxes).float()
+        lbboxes = torch.from_numpy(lbboxes).float()
+
+        return img, label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes
+
+    def __load_annotations(self, anno_type):
+        assert anno_type in ['train', 'val', 'test']
+        anno_path = os.path.join(cfg.PROJECT_PATH, 'data', anno_type + ".txt")
+        with open(anno_path, 'r') as f:
+            annotations = list(filter(lambda x: len(x) > 0, f.readlines()))
+        assert len(annotations) > 0, "No images found in {}".format(anno_path)
+        return annotations
+
+    def __parse_annotation(self, annotation):
+        """
+        Data augmentation.
+        :param annotation: image path plus bbox coordinates and categories, e.g.
+            [image_path xmin,ymin,xmax,ymax,class_ind xmin,ymin,xmax,ymax,class_ind ...]
+        :return: the augmented image and bboxes; each bbox is [xmin, ymin, xmax, ymax, class_ind]
+        """
+        anno = annotation.strip().split(' ')
+
+        img_path = anno[0]
+        img = cv2.imread(img_path)  # H*W*C, BGR
+        assert img is not None, 'File Not Found ' + img_path
+        bboxes = np.array([list(map(float, box.split(','))) for box in anno[1:]])
+
+        img, bboxes = DataAug.RandomVerticalFlip()(np.copy(img), np.copy(bboxes))
+        img, bboxes = DataAug.RandomHorizontalFlip()(np.copy(img), np.copy(bboxes))
+        img, bboxes = DataAug.HSV()(np.copy(img), np.copy(bboxes))
+        img, bboxes = DataAug.RandomCrop()(np.copy(img), np.copy(bboxes))
+        img, bboxes = DataAug.RandomAffine()(np.copy(img), np.copy(bboxes))
+        img, bboxes = DataAug.Resize((self.img_size, self.img_size), True)(np.copy(img), np.copy(bboxes))
+
+        return img, bboxes
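Aside (not part of the diff): the annotation-line format expected by __parse_annotation, with a hypothetical path and boxes:

import numpy as np

line = "images/00001.jpg 48,240,195,371,11 8,12,352,498,14"  # path xmin,ymin,xmax,ymax,class ...
anno = line.strip().split(' ')
bboxes = np.array([list(map(float, box.split(','))) for box in anno[1:]])
print(bboxes.shape)  # (2, 5)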
+    def __create_label(self, bboxes):
+        """
+        Label assignment: every GT bbox of the image is assigned to anchors.
+        1. Take each bbox in turn, convert its coordinates from "xyxy" to "xywh", and
+           scale the xywh by each detection layer's stride.
+        2. For each detection layer, compute the IoU between the layer's anchors and the
+           scaled bbox, and let every anchor with IoU > 0.3 predict the bbox. If the IoUs
+           on all detection layers are below 0.3, assign the single best-matching anchor
+           across all layers.
+
+        Note:
+        1. The same GT may be assigned to multiple anchors, on the same or different layers.
+        2. The total number of assigned boxes can exceed the number of GTs, because the
+           same GT may be assigned on several detection layers.
+        """
+        anchors = np.array(cfg.MODEL["ANCHORS"])
+        strides = np.array(cfg.MODEL["STRIDES"])
+        train_output_size = self.img_size / strides
+        anchors_per_scale = cfg.MODEL["ANCHORS_PER_SCLAE"]  # key name as defined in the cfg module
+
+        # label layout per anchor: [x, y, w, h, conf, mix_weight, one-hot classes]
+        label = [np.zeros((int(train_output_size[i]), int(train_output_size[i]), anchors_per_scale,
+                           6 + self.num_classes)) for i in range(3)]
+        for i in range(3):
+            label[i][..., 5] = 1.0
+
+        bboxes_xywh = [np.zeros((150, 4)) for _ in range(3)]  # up to 150 boxes per scale (Darknet uses 30)
+        bbox_count = np.zeros((3,))
+
+        for bbox in bboxes:
+            bbox_coor = bbox[:4]           # xyxy coordinates
+            bbox_class_ind = int(bbox[4])  # class id
+            bbox_mix = bbox[5]             # mixup weight
+
+            # one-hot class vector with label smoothing
+            one_hot = np.zeros(self.num_classes, dtype=np.float32)
+            one_hot[bbox_class_ind] = 1.0
+            one_hot_smooth = DataAug.LabelSmooth()(one_hot, self.num_classes)
+
+            # convert "xyxy" to "xywh"
+            bbox_xywh = np.concatenate([(bbox_coor[2:] + bbox_coor[:2]) * 0.5,
+                                        bbox_coor[2:] - bbox_coor[:2]], axis=-1)
+
+            bbox_xywh_scaled = 1.0 * bbox_xywh[np.newaxis, :] / strides[:, np.newaxis]
+
+            iou = []
+            exist_positive = False
+            for i in range(3):
+                # anchors are placed at the center of the grid cell containing the bbox center
+                anchors_xywh = np.zeros((anchors_per_scale, 4))
+                anchors_xywh[:, 0:2] = np.floor(bbox_xywh_scaled[i, 0:2]).astype(np.int32) + 0.5
+                anchors_xywh[:, 2:4] = anchors[i]
+
+                iou_scale = tools.iou_xywh_numpy(bbox_xywh_scaled[i][np.newaxis, :], anchors_xywh)
+                iou.append(iou_scale)
+                iou_mask = iou_scale > 0.3
+
+                if np.any(iou_mask):
+                    xind, yind = np.floor(bbox_xywh_scaled[i, 0:2]).astype(np.int32)
+                    # Known bug: when several bboxes map to the same anchor,
+                    # the anchor ends up assigned to the last bbox processed.
+                    label[i][yind, xind, iou_mask, 0:4] = bbox_xywh
+                    label[i][yind, xind, iou_mask, 4:5] = 1.0
+                    label[i][yind, xind, iou_mask, 5:6] = bbox_mix
+                    label[i][yind, xind, iou_mask, 6:] = one_hot_smooth
+
+                    bbox_ind = int(bbox_count[i] % 150)  # 150 is a prior on the max detections; memory-hungry
+                    bboxes_xywh[i][bbox_ind, :4] = bbox_xywh
+                    bbox_count[i] += 1
+
+                    exist_positive = True
+
+            if not exist_positive:
+                best_anchor_ind = np.argmax(np.array(iou).reshape(-1), axis=-1)
+                best_detect = int(best_anchor_ind / anchors_per_scale)
+                best_anchor = int(best_anchor_ind % anchors_per_scale)
+
+                xind, yind = np.floor(bbox_xywh_scaled[best_detect, 0:2]).astype(np.int32)
+
+                label[best_detect][yind, xind, best_anchor, 0:4] = bbox_xywh
+                label[best_detect][yind, xind, best_anchor, 4:5] = 1.0
+                label[best_detect][yind, xind, best_anchor, 5:6] = bbox_mix
+                label[best_detect][yind, xind, best_anchor, 6:] = one_hot_smooth
+
+                bbox_ind = int(bbox_count[best_detect] % 150)  # max number of detections per scale
+                bboxes_xywh[best_detect][bbox_ind, :4] = bbox_xywh
+                bbox_count[best_detect] += 1
+
+        label_sbbox, label_mbbox, label_lbbox = label
+        sbboxes, mbboxes, lbboxes = bboxes_xywh
+
+        return label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes
+
+
+if __name__ == '__main__':
+    from torch.utils.data import DataLoader
+    train_dataset = Construct_Dataset(anno_file_type="train", img_size=cfg.TRAIN["TRAIN_IMG_SIZE"])
+    train_dataloader = DataLoader(train_dataset, batch_size=cfg.TRAIN["BATCH_SIZE"],
+                                  num_workers=cfg.TRAIN["NUMBER_WORKERS"], shuffle=False)
+    for i, (imgs, label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes) in enumerate(train_dataloader):
+        continue
\ No newline at end of file
diff --git a/modelR/head/dsc_head_hbb.py b/modelR/head/dsc_head_hbb.py
new file mode 100644
index 0000000..ddbfbd8
--- /dev/null
+++ b/modelR/head/dsc_head_hbb.py
@@ -0,0 +1,40 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+import config.cfg_lodet as cfg
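Aside (not part of the diff): a self-contained sketch of the xywh IoU that tools.iou_xywh_numpy is assumed to compute (its definition lives in utils/utils_basic.py, outside this diff). Anchor matching above compares one scaled GT box against a layer's anchors, all as [xc, yc, w, h] in grid units.

import numpy as np

def iou_xywh(box, anchors):
    # box: (1, 4), anchors: (A, 4); returns (A,) IoU values
    b_min, b_max = box[:, :2] - box[:, 2:] / 2, box[:, :2] + box[:, 2:] / 2
    a_min, a_max = anchors[:, :2] - anchors[:, 2:] / 2, anchors[:, :2] + anchors[:, 2:] / 2
    inter_wh = np.maximum(np.minimum(b_max, a_max) - np.maximum(b_min, a_min), 0)
    inter = inter_wh[:, 0] * inter_wh[:, 1]
    union = box[:, 2] * box[:, 3] + anchors[:, 2] * anchors[:, 3] - inter
    return inter / union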
+
+
+class Ordinary_Head(nn.Module):
+    def __init__(self, nC, anchors, stride):
+        super(Ordinary_Head, self).__init__()
+        self.__anchors = anchors
+        self.__nA = len(anchors)
+        self.__nC = nC
+        self.__stride = stride
+
+    def forward(self, p):
+        bs, nG = p.shape[0], p.shape[-1]
+        # (bs, nA*(5+nC), nG, nG) -> (bs, nG, nG, nA, 5+nC)
+        p = p.view(bs, self.__nA, 5 + self.__nC, nG, nG).permute(0, 3, 4, 1, 2)
+        p_de = self.__decode(p.clone())
+        return (p, p_de)
+
+    def __decode(self, p):
+        batch_size, output_size = p.shape[:2]
+        device = p.device
+        stride = self.__stride
+        anchors = (1.0 * self.__anchors).to(device)
+        conv_raw_dxdy = p[:, :, :, :, 0:2]
+        conv_raw_dwdh = p[:, :, :, :, 2:4]
+        conv_raw_conf = p[:, :, :, :, 4:5]
+        conv_raw_prob = p[:, :, :, :, 5:]
+        # grid of cell indices, broadcast to (bs, nG, nG, nA, 2)
+        y = torch.arange(0, output_size).unsqueeze(1).repeat(1, output_size)
+        x = torch.arange(0, output_size).unsqueeze(0).repeat(output_size, 1)
+        grid_xy = torch.stack([x, y], dim=-1)
+        grid_xy = grid_xy.unsqueeze(0).unsqueeze(3).repeat(batch_size, 1, 1, self.__nA, 1).float().to(device)
+        pred_xy = (torch.sigmoid(conv_raw_dxdy) + grid_xy) * stride
+        pred_wh = (torch.exp(conv_raw_dwdh) * anchors) * stride
+        pred_xywh = torch.cat([pred_xy, pred_wh], dim=-1)
+        pred_conf = torch.sigmoid(conv_raw_conf)
+        pred_prob = torch.sigmoid(conv_raw_prob)
+        pred_bbox = torch.cat([pred_xywh, pred_conf, pred_prob], dim=-1)
+
+        return pred_bbox.view(-1, 5 + self.__nC) if not self.training else pred_bbox
\ No newline at end of file
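Aside (not part of the diff): the decode above follows the usual YOLO parameterisation, xy = (sigmoid(txy) + grid) * stride and wh = exp(twh) * anchor * stride. A one-cell numeric sketch with hypothetical values:

import torch

stride = 8.0
anchor = torch.tensor([1.25, 1.625])       # anchor w, h in grid units
raw = torch.tensor([0.2, -0.1, 0.3, 0.4])  # raw tx, ty, tw, th at cell (gx, gy)
gx, gy = 5, 7
xy = (torch.sigmoid(raw[:2]) + torch.tensor([gx, gy])) * stride
wh = torch.exp(raw[2:]) * anchor * stride
print(xy, wh)  # box center and size in input-image pixels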