|
| 1 | +import argparse |
| 2 | + |
| 3 | +import cv2 as cv |
| 4 | +import glog as log |
| 5 | +import numpy as np |
| 6 | +from openvino.inference_engine import IECore |
| 7 | + |
| 8 | +from demo_tools import load_ie_model |
| 9 | +from torchdet3d.utils import draw_kp |
| 10 | + |
| 11 | + |
# Class names predicted by the Objectron regression model, indexed by integer label id.
OBJECTRON_CLASSES = ('bike', 'book', 'bottle', 'cereal_box', 'camera', 'chair', 'cup', 'laptop', 'shoe')
| 13 | + |
class Detector:
    """Wrapper around an SSD-style object detection network.

    Runs inference through an IE model and converts the raw output into
    a list of ``((left, top, right, bottom), confidence, label)`` tuples
    sorted by descending confidence.
    """

    def __init__(self, ie, model_path, conf=.6, device='CPU', ext_path=''):
        self.net = load_ie_model(ie, model_path, device, None, ext_path)
        self.confidence = conf          # minimum score kept (strict >)
        self.expand_ratio = (1., 1.)    # optional box inflation factors (w, h)

    def get_detections(self, frame):
        """Returns all detections on frame"""
        raw_out = self.net.forward(frame)
        return self.__decode_detections(raw_out, frame.shape)

    def __decode_detections(self, out, frame_shape):
        """Decodes raw SSD output into pixel-space boxes."""
        height, width = frame_shape[0], frame_shape[1]
        decoded = []

        for det in out[0, 0]:
            cls_id, score = det[1], det[2]
            if score <= self.confidence:
                continue
            # Normalized coords are clamped at 0 and scaled to pixels.
            x0 = int(max(det[3], 0) * width)
            y0 = int(max(det[4], 0) * height)
            x1 = int(max(det[5], 0) * width)
            y1 = int(max(det[6], 0) * height)
            if self.expand_ratio != (1., 1.):
                # Grow the box symmetrically around its center.
                half_dw = (x1 - x0) * (self.expand_ratio[0] - 1.) / 2
                half_dh = (y1 - y0) * (self.expand_ratio[1] - 1.) / 2
                x0 = max(int(x0 - half_dw), 0)
                x1 = int(x1 + half_dw)
                y0 = max(int(y0 - half_dh), 0)
                y1 = int(y1 + half_dh)

            decoded.append(((x0, y0, x1, y1), score, cls_id))

        if len(decoded) > 1:
            decoded.sort(key=lambda item: item[1], reverse=True)
        return decoded
| 54 | + |
| 55 | + |
class Regressor:
    """Wrapper around the keypoint regression network.

    Crops each detected box out of the frame, runs the regression model on
    the crop, and maps the predicted keypoints back into frame coordinates.
    """

    def __init__(self, ie, model_path, device='CPU', ext_path=''):
        self.net = load_ie_model(ie, model_path, device, None, ext_path)

    def get_detections(self, frame, detections):
        """Returns a ``(keypoints, label)`` tuple per input detection."""
        results = []
        for det in detections:
            crop_img = self.crop(frame, det[0])
            raw = self.net.forward(crop_img)
            results.append(self.__decode_detections(raw, det))
        return results

    def __decode_detections(self, out, rect):
        """Decodes raw regression model output for one detection."""
        label = int(rect[2])
        # The model emits one keypoint set per class; pick the detected one.
        class_kp = out[label]
        kp = self.transform_kp(class_kp[0], rect[0])
        return (kp, label)

    @staticmethod
    def transform_kp(kp: np.array, crop_cords: tuple):
        """Map normalized crop-space keypoints to frame pixels (in place)."""
        x0, y0, x1, y1 = crop_cords
        kp[:, 0] = kp[:, 0] * (x1 - x0) + x0
        kp[:, 1] = kp[:, 1] * (y1 - y0) + y0
        return kp

    @staticmethod
    def crop(frame, rect):
        """Return the sub-image of ``frame`` inside ``rect`` (x0, y0, x1, y1)."""
        x0, y0, x1, y1 = rect
        return frame[y0:y1, x0:x1]
| 94 | + |
def draw_detections(frame, reg_detections, det_detections, reg_only=True):
    """Render keypoints (and optionally boxes + class labels) onto ``frame``.

    ``reg_detections`` and ``det_detections`` are iterated in lockstep; the
    i-th regression output is assumed to correspond to the i-th detection.
    """
    for det_item, reg_item in zip(det_detections, reg_detections):
        left, top, right, bottom = det_item[0]
        kp, class_idx = reg_item[0], reg_item[1]
        label = OBJECTRON_CLASSES[class_idx]

        if not reg_only:
            cv.rectangle(frame, (left, top), (right, bottom), (0, 255, 0), thickness=2)

        frame = draw_kp(frame, kp, None, RGB=False, normalized=False)

        # White backdrop behind the class name so it stays readable.
        label_size, base_line = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 1, 1)
        top = max(top, label_size[1])
        cv.rectangle(frame, (left, top - label_size[1]), (left + label_size[0], top + base_line),
                     (255, 255, 255), cv.FILLED)
        cv.putText(frame, label, (left, top), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0))

    return frame
| 113 | + |
def run(params, capture, detector, regressor, write_video=False, resolution=(1280, 720)):
    """Starts the 3D object detection demo loop.

    Reads frames from ``capture``, runs detection + keypoint regression,
    displays the annotated frames, and optionally records them. Exits on
    Esc or when the stream ends; all resources are released on exit.

    Args:
        params: parsed CLI arguments (unused here; kept for interface compat).
        capture: an opened ``cv.VideoCapture``.
        detector: ``Detector`` instance producing 2D boxes.
        regressor: ``Regressor`` instance producing 3D keypoints.
        write_video: record the annotated stream to 'output_video_demo.mp4'.
        resolution: (width, height) the frames are resized to.
    """
    writer_video = None
    if write_video:
        # Only create the writer (and its fourcc) when actually recording.
        fourcc = cv.VideoWriter_fourcc(*'MP4V')
        fps = 24
        writer_video = cv.VideoWriter('output_video_demo.mp4', fourcc, fps, resolution)
    win_name = '3D-object-detection'
    while cv.waitKey(1) != 27:  # 27 == Esc
        has_frame, frame = capture.read()
        # BUG FIX: check has_frame BEFORE resizing — at end of stream
        # frame is None and cv.resize would raise. Also break (not
        # return) so the releases below always run.
        if not has_frame:
            break
        frame = cv.resize(frame, resolution)
        detections = detector.get_detections(frame)
        outputs = regressor.get_detections(frame, detections)

        frame = draw_detections(frame, outputs, detections, reg_only=False)
        cv.imshow(win_name, frame)
        if writer_video is not None:
            writer_video.write(cv.resize(frame, resolution))
    # BUG FIX: release the writer only if it was created; the original
    # unconditional release raised NameError when write_video was False.
    if writer_video is not None:
        writer_video.release()
    capture.release()
    cv.destroyAllWindows()
| 136 | + |
def main():
    """Parses CLI arguments, opens the capture source, builds the models,
    and launches the 3D object detection demo.

    Raises:
        AssertionError: if no video input was given for file mode, or the
            capture source cannot be opened.
    """
    parser = argparse.ArgumentParser(description='3d object detection live demo script')
    parser.add_argument('--video', type=str, default=None, help='Input video')
    parser.add_argument('--cam_id', type=int, default=-1, help='Input cam')
    # BUG FIX: a default keeps `tuple(args.resolution)` below from raising
    # TypeError on None when the flag is omitted; matches run()'s default.
    parser.add_argument('--resolution', type=int, nargs='+', default=[1280, 720],
                        help='capture resolution')
    parser.add_argument('--config', type=str, default=None, required=False,
                        help='Configuration file')
    parser.add_argument('--od_model', type=str, required=True)
    parser.add_argument('--reg_model', type=str, required=True)
    parser.add_argument('--det_tresh', type=float, required=False, default=0.6)
    parser.add_argument('--device', type=str, default='CPU')
    parser.add_argument('-l', '--cpu_extension',
                        help='MKLDNN (CPU)-targeted custom layers.Absolute path to a shared library with the kernels '
                             'impl.', type=str, default=None)
    # BUG FIX: argparse `type=bool` turns ANY non-empty string (including
    # "False") into True; a store_true flag is the correct boolean switch.
    parser.add_argument('--write_video', action='store_true',
                        help='if set, the video of the demo will be recorded')
    args = parser.parse_args()

    if args.cam_id >= 0:
        log.info('Reading from cam {}'.format(args.cam_id))
        cap = cv.VideoCapture(args.cam_id)
        cap.set(cv.CAP_PROP_FRAME_WIDTH, args.resolution[0])
        cap.set(cv.CAP_PROP_FRAME_HEIGHT, args.resolution[1])
        cap.set(cv.CAP_PROP_FOURCC, cv.VideoWriter_fourcc(*'MJPG'))
    else:
        assert args.video, "No video input was given"
        log.info('Reading from {}'.format(args.video))
        cap = cv.VideoCapture(args.video)
        cap.set(cv.CAP_PROP_FOURCC, cv.VideoWriter_fourcc(*'MJPG'))
    assert cap.isOpened()
    ie = IECore()
    object_detector = Detector(ie, args.od_model, args.det_tresh, args.device, args.cpu_extension)
    regressor = Regressor(ie, args.reg_model, args.device, args.cpu_extension)
    # running demo
    run(args, cap, object_detector, regressor, args.write_video, tuple(args.resolution))

if __name__ == '__main__':
    main()
0 commit comments