
Commit 68bb076

committed
Add object detection
1 parent b3ee86c commit 68bb076

File tree

4 files changed: +343 -0 lines changed


detection/Base-RCNN-C4.yaml

+18
@@ -0,0 +1,18 @@
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  RPN:
    PRE_NMS_TOPK_TEST: 6000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "Res5ROIHeads"
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.02
  STEPS: (60000, 80000)
  MAX_ITER: 90000
INPUT:
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
VERSION: 2

detection/__init__.py

+7
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
"""
@File : __init__.py
@Time : 2020/2/24 10:08 PM
@Author : yizuotian
@Description :
"""

detection/demo.py

+300
@@ -0,0 +1,300 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import multiprocessing as mp
import os

import cv2
import detectron2.data.transforms as T
import numpy as np
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.data.detection_utils import read_image
from detectron2.modeling import build_model
from detectron2.utils.logger import setup_logger
from skimage import io
from torch import nn

# constants
WINDOW_NAME = "COCO detections"


def setup_cfg(args):
    # load config from file and command-line arguments
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Set score_threshold for builtin models
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
    cfg.freeze()
    return cfg


def get_last_conv_name(net):
    """
    Get the name of the last convolutional layer in the network.
    :param net:
    :return:
    """
    layer_name = None
    for name, m in net.named_modules():
        if isinstance(m, nn.Conv2d):
            layer_name = name
    return layer_name


class GradCAM(object):
    """
    1: the network's weights are not updated; only the input image requires gradients
    2: back-propagation is driven by the score of the target class (here, the selected detection box)
    """

    def __init__(self, net, layer_name):
        self.net = net
        self.layer_name = layer_name
        self.feature = None
        self.gradient = None
        self.net.eval()
        self.handlers = []
        self._register_hook()

    def _get_features_hook(self, module, input, output):
        self.feature = output
        print("feature shape:{}".format(output.size()))

    def _get_grads_hook(self, module, input_grad, output_grad):
        """

        :param input_grad: tuple, input_grad[0]: None
                           input_grad[1]: weight
                           input_grad[2]: bias
        :param output_grad: tuple of length 1
        :return:
        """
        self.gradient = output_grad[0]

    def _register_hook(self):
        for (name, module) in self.net.named_modules():
            if name == self.layer_name:
                self.handlers.append(module.register_forward_hook(self._get_features_hook))
                self.handlers.append(module.register_backward_hook(self._get_grads_hook))

    def remove_handlers(self):
        for handle in self.handlers:
            handle.remove()

    def __call__(self, inputs, index=0):
        """

        :param inputs: {"image": [C,H,W], "height": height, "width": width}
        :param index: index of the predicted box
        :return:
        """
        self.net.zero_grad()
        output = self.net.inference([inputs])
        print(output)
        score = output[0]['instances'].scores[index]
        proposal_idx = output[0]['instances'].indices[index]  # which proposal the box came from
        score.backward()

        gradient = self.gradient[proposal_idx].cpu().data.numpy()  # [C,H,W]
        weight = np.mean(gradient, axis=(1, 2))  # [C]

        feature = self.feature[proposal_idx].cpu().data.numpy()  # [C,H,W], same proposal as the gradient

        cam = feature * weight[:, np.newaxis, np.newaxis]  # [C,H,W]
        cam = np.sum(cam, axis=0)  # [H,W]
        cam = np.maximum(cam, 0)  # ReLU

        # normalize to [0, 1]
        cam -= np.min(cam)
        cam /= np.max(cam)
        # resize to the size of the predicted box
        box = output[0]['instances'].pred_boxes.tensor[index].detach().numpy().astype(np.int32)
        x1, y1, x2, y2 = box
        cam = cv2.resize(cam, (x2 - x1, y2 - y1))

        class_id = output[0]['instances'].pred_classes[index].detach().numpy()
        return cam, box, class_id


class GuidedBackPropagation(object):

    def __init__(self, net):
        self.net = net
        for (name, module) in self.net.named_modules():
            if isinstance(module, nn.ReLU):
                module.register_backward_hook(self.backward_hook)
        self.net.eval()

    @classmethod
    def backward_hook(cls, module, grad_in, grad_out):
        """

        :param module:
        :param grad_in: tuple of length 1
        :param grad_out: tuple of length 1
        :return: tuple(new_grad_in,)
        """
        return torch.clamp(grad_in[0], min=0.0),

    def __call__(self, inputs, index=0):
        """

        :param inputs: {"image": [C,H,W], "height": height, "width": width}
        :param index: index of the predicted box
        :return:
        """
        self.net.zero_grad()
        output = self.net.inference([inputs])
        score = output[0]['instances'].scores[index]
        score.backward()

        return inputs['image'].grad  # [3,H,W]; the image tensor carries no batch dimension here


def norm_image(image):
    """
    Normalize an image to the uint8 range [0, 255].
    :param image: [H,W,C]
    :return:
    """
    image = image.copy()
    image -= np.maximum(np.min(image), 0)  # shift by the minimum, clipped at zero
    image /= np.max(image)
    image *= 255.
    return np.uint8(image)


def gen_cam(image, mask):
    """
    Generate the CAM visualization.
    :param image: [H,W,C], original image
    :param mask: [H,W], values in the range 0~1
    :return: tuple(cam, heatmap)
    """
    # convert the mask to a heatmap
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    heatmap = heatmap[..., ::-1]  # BGR to RGB

    # overlay the heatmap on the original image
    cam = heatmap + np.float32(image)
    return norm_image(cam), heatmap


def gen_gb(grad):
    """
    Generate the guided backpropagation gradient of the input image.
    :param grad: tensor, [3,H,W]
    :return:
    """
    # to numpy, channel-last [H,W,C]
    grad = grad.data.numpy()
    gb = np.transpose(grad, (1, 2, 0))
    return gb


def save_image(image_dicts, input_image_name, network='frcnn', output_dir='./results'):
    os.makedirs(output_dir, exist_ok=True)  # make sure the output directory exists
    prefix = os.path.splitext(input_image_name)[0]
    for key, image in image_dicts.items():
        io.imsave(os.path.join(output_dir, '{}-{}-{}.jpg'.format(prefix, network, key)), image)


def get_parser():
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin models")
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--input", help="A list of space separated input images")
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",
    )

    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


if __name__ == "__main__":
    """
    Usage: export KMP_DUPLICATE_LIB_OK=TRUE
    python detection/demo.py --config-file detection/faster_rcnn_R_50_C4.yaml \
        --input ./examples/pic1.jpg \
        --opts MODEL.WEIGHTS /Users/yizuotian/pretrained_model/model_final_b1acc2.pkl MODEL.DEVICE cpu
    """
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)
    print(cfg)
    # Build the model
    model = build_model(cfg)
    # Load the weights
    checkpointer = DetectionCheckpointer(model)
    checkpointer.load(cfg.MODEL.WEIGHTS)

    # Load the image
    path = os.path.expanduser(args.input)
    original_image = read_image(path, format="BGR")
    height, width = original_image.shape[:2]
    transform_gen = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    )
    image = transform_gen.get_transform(original_image).apply_image(original_image)
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)).requires_grad_(True)

    inputs = {"image": image, "height": height, "width": width}

    # Grad-CAM
    layer_name = get_last_conv_name(model)
    grad_cam = GradCAM(model, layer_name)
    mask, box, class_id = grad_cam(inputs)  # cam mask
    grad_cam.remove_handlers()
    #
    image_dict = {}
    img = original_image[..., ::-1]
    x1, y1, x2, y2 = box
    image_dict['predict_box'] = img[y1:y2, x1:x2]
    image_cam, image_dict['heatmap'] = gen_cam(img[y1:y2, x1:x2], mask)

    # Get the class name
    meta = MetadataCatalog.get(
        cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
    )
    label = meta.thing_classes[class_id]

    print("label:{}".format(label))
    # GuidedBackPropagation
    # gbp = GuidedBackPropagation(model)
    # inputs['image'].grad.zero_()  # zero the gradients
    # grad = gbp(inputs)
    # print("grad.shape:{}".format(grad.shape))
    # gb = gen_gb(grad)
    # image_dict['gb'] = gb
    # Generate Guided Grad-CAM
    # cam_gb = gb * mask[..., np.newaxis]
    # image_dict['cam_gb'] = norm_image(cam_gb)

    save_image(image_dict, os.path.basename(path))
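
The computation in GradCAM.__call__ above reduces to the standard Grad-CAM weighting: each channel's weight is the spatial mean of the gradient of the selected box score with respect to the hooked feature map, and the map is the ReLU of the channel-weighted sum, normalized and then resized to the predicted box with cv2.resize. A minimal, self-contained NumPy sketch of that reduction (illustration only, not part of this commit; random arrays stand in for the tensors captured by the forward/backward hooks):

import numpy as np

# stand-ins for the hooked feature map and gradient of one proposal
feature = np.random.rand(2048, 7, 7).astype(np.float32)    # [C,H,W]
gradient = np.random.rand(2048, 7, 7).astype(np.float32)   # [C,H,W], d(score)/d(feature)

weight = gradient.mean(axis=(1, 2))                        # [C]: one weight per channel
cam = (feature * weight[:, None, None]).sum(axis=0)        # [H,W]: channel-weighted sum
cam = np.maximum(cam, 0)                                   # ReLU keeps only positive evidence
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)   # normalize to [0, 1] before resizing to the box

In the demo, this normalized map is resized to (x2 - x1, y2 - y1) and blended with the box crop by gen_cam.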

detection/faster_rcnn_R_50_C4.yaml

+18
@@ -0,0 +1,18 @@
_BASE_: "./Base-RCNN-C4.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
  MASK_ON: False
  RESNETS:
    DEPTH: 50
  ROI_HEADS:
    NUM_CLASSES: 20
INPUT:
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  MIN_SIZE_TEST: 800
DATASETS:
  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
  TEST: ('voc_2007_test',)
SOLVER:
  STEPS: (12000, 16000)
  MAX_ITER: 18000  # 17.4 epochs
  WARMUP_ITERS: 100
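
Because of the _BASE_ field, faster_rcnn_R_50_C4.yaml inherits every key it does not override from Base-RCNN-C4.yaml; detectron2 resolves _BASE_ relative to the config file's own directory when the file is merged. A minimal sketch of loading the merged config (illustration only, not part of this commit; it assumes detectron2 is installed and both YAML files sit under detection/ as added here):

from detectron2.config import get_cfg

cfg = get_cfg()                                             # start from detectron2's defaults
cfg.merge_from_file("detection/faster_rcnn_R_50_C4.yaml")   # _BASE_ pulls in Base-RCNN-C4.yaml first
print(cfg.MODEL.RESNETS.DEPTH)    # 50, set in this file
print(cfg.SOLVER.IMS_PER_BATCH)   # 16, inherited from the base config
print(cfg.DATASETS.TRAIN)         # ('voc_2007_trainval', 'voc_2012_trainval'), overriding the base's COCO datasets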
