@@ -71,10 +71,17 @@ def gen_cam(image, mask):
71
71
return norm_image (cam ), heatmap
72
72
73
73
74
- def save_image (image_dicts , input_image_name , network = 'retinanet' , output_dir = './results' ):
74
+ def save_image (image_dicts , input_image_name , layer_name , network = 'retinanet' , output_dir = './results' ):
75
75
prefix = os .path .splitext (input_image_name )[0 ]
76
76
for key , image in image_dicts .items ():
77
- io .imsave (os .path .join (output_dir , '{}-{}-{}.jpg' .format (prefix , network , key )), image )
77
+ if key == 'predict_box' :
78
+ io .imsave (os .path .join (output_dir ,
79
+ '{}-{}-{}.jpg' .format (prefix , network , key )),
80
+ image )
81
+ else :
82
+ io .imsave (os .path .join (output_dir ,
83
+ '{}-{}-{}-{}.jpg' .format (prefix , network , layer_name , key )),
84
+ image )
78
85
79
86
80
87
def get_parser ():
@@ -104,7 +111,7 @@ def get_parser():
104
111
default = [],
105
112
nargs = argparse .REMAINDER ,
106
113
)
107
- parser .add_argument ('--layer-name' , type = str , default = 'head.cls_subnet.0 ' ,
114
+ parser .add_argument ('--layer-name' , type = str , default = 'head.cls_subnet.2 ' ,
108
115
help = '使用哪层特征去生成CAM' )
109
116
return parser
110
117
@@ -150,6 +157,7 @@ def main(args):
150
157
# Grad-CAM++
151
158
grad_cam_plus_plus = GradCamPlusPlus (model , layer_name )
152
159
mask_plus_plus = grad_cam_plus_plus (inputs ) # cam mask
160
+
153
161
_ , image_dict ['heatmap++' ] = gen_cam (img [y1 :y2 , x1 :x2 ], mask_plus_plus [y1 :y2 , x1 :x2 ])
154
162
grad_cam_plus_plus .remove_handlers ()
155
163
@@ -161,14 +169,15 @@ def main(args):
161
169
162
170
print ("label:{}" .format (label ))
163
171
164
- save_image (image_dict , os .path .basename (path ))
172
+ save_image (image_dict , os .path .basename (path ), args . layer_name )
165
173
166
174
167
175
if __name__ == "__main__" :
168
176
"""
169
177
Usage:export KMP_DUPLICATE_LIB_OK=TRUE
170
178
python detection/demo_retinanet.py --config-file detection/retinanet_R_50_FPN_3x.yaml \
171
179
--input ./examples/pic1.jpg \
180
+ --layer-name head.cls_subnet.7 \
172
181
--opts MODEL.WEIGHTS /Users/yizuotian/pretrained_model/model_final_4cafe0.pkl MODEL.DEVICE cpu
173
182
"""
174
183
mp .set_start_method ("spawn" , force = True )
0 commit comments