Skip to content

Commit 9c4ddc2

Browse files
committed
fix bug:参数顺序
1 parent 5eb54b4 commit 9c4ddc2

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

detection/demo_retinanet.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,17 @@ def gen_cam(image, mask):
7171
return norm_image(cam), heatmap
7272

7373

74-
def save_image(image_dicts, input_image_name, network='retinanet', output_dir='./results'):
74+
def save_image(image_dicts, input_image_name, layer_name, network='retinanet', output_dir='./results'):
7575
prefix = os.path.splitext(input_image_name)[0]
7676
for key, image in image_dicts.items():
77-
io.imsave(os.path.join(output_dir, '{}-{}-{}.jpg'.format(prefix, network, key)), image)
77+
if key == 'predict_box':
78+
io.imsave(os.path.join(output_dir,
79+
'{}-{}-{}.jpg'.format(prefix, network, key)),
80+
image)
81+
else:
82+
io.imsave(os.path.join(output_dir,
83+
'{}-{}-{}-{}.jpg'.format(prefix, network, layer_name, key)),
84+
image)
7885

7986

8087
def get_parser():
@@ -104,7 +111,7 @@ def get_parser():
104111
default=[],
105112
nargs=argparse.REMAINDER,
106113
)
107-
parser.add_argument('--layer-name', type=str, default='head.cls_subnet.0',
114+
parser.add_argument('--layer-name', type=str, default='head.cls_subnet.2',
108115
help='使用哪层特征去生成CAM')
109116
return parser
110117

@@ -150,6 +157,7 @@ def main(args):
150157
# Grad-CAM++
151158
grad_cam_plus_plus = GradCamPlusPlus(model, layer_name)
152159
mask_plus_plus = grad_cam_plus_plus(inputs) # cam mask
160+
153161
_, image_dict['heatmap++'] = gen_cam(img[y1:y2, x1:x2], mask_plus_plus[y1:y2, x1:x2])
154162
grad_cam_plus_plus.remove_handlers()
155163

@@ -161,14 +169,15 @@ def main(args):
161169

162170
print("label:{}".format(label))
163171

164-
save_image(image_dict, os.path.basename(path))
172+
save_image(image_dict, os.path.basename(path), args.layer_name)
165173

166174

167175
if __name__ == "__main__":
168176
"""
169177
Usage:export KMP_DUPLICATE_LIB_OK=TRUE
170178
python detection/demo_retinanet.py --config-file detection/retinanet_R_50_FPN_3x.yaml \
171179
--input ./examples/pic1.jpg \
180+
--layer-name head.cls_subnet.7 \
172181
--opts MODEL.WEIGHTS /Users/yizuotian/pretrained_model/model_final_4cafe0.pkl MODEL.DEVICE cpu
173182
"""
174183
mp.set_start_method("spawn", force=True)

0 commit comments

Comments
 (0)