Skip to content

Commit aae33dc

Browse files
committed
增加Grad-CAM++
1 parent 9fb11ea commit aae33dc

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

interpretability/grad_cam.py

+42
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,45 @@ def __call__(self, inputs, index):
7878
# resize to 224*224
7979
cam = cv2.resize(cam, (224, 224))
8080
return cam
81+
82+
83+
class GradCamPlusPlus(GradCAM):
84+
def __init__(self, net, layer_name):
85+
super(GradCamPlusPlus, self).__init__(net, layer_name)
86+
87+
def __call__(self, inputs, index):
88+
"""
89+
90+
:param inputs: [1,3,H,W]
91+
:param index: class id
92+
:return:
93+
"""
94+
self.net.zero_grad()
95+
output = self.net(inputs) # [1,num_classes]
96+
if index is None:
97+
index = np.argmax(output.cpu().data.numpy())
98+
target = output[0][index]
99+
target.backward()
100+
101+
gradient = self.gradient[0].cpu().data.numpy() # [C,H,W]
102+
gradient = np.maximum(gradient, 0.) # ReLU
103+
indicate = np.where(gradient > 0, 1., 0.) # 示性函数
104+
norm_factor = np.sum(gradient, axis=(1, 2)) # [C]归一化
105+
for i in range(len(norm_factor)):
106+
norm_factor[i] = 1. / norm_factor[i] if norm_factor[i] > 0. else 0. # 避免除零
107+
alpha = indicate * norm_factor[:, np.newaxis, np.newaxis] # [C,H,W]
108+
109+
weight = np.sum(gradient * alpha, axis=(1, 2)) # [C] alpha*ReLU(gradient)
110+
111+
feature = self.feature[0].cpu().data.numpy() # [C,H,W]
112+
113+
cam = feature * weight[:, np.newaxis, np.newaxis] # [C,H,W]
114+
cam = np.sum(cam, axis=0) # [H,W]
115+
# cam = np.maximum(cam, 0) # ReLU
116+
117+
# 数值归一化
118+
cam -= np.min(cam)
119+
cam /= np.max(cam)
120+
# resize to 224*224
121+
cam = cv2.resize(cam, (224, 224))
122+
return cam

main.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import argparse
1717
from skimage import io
1818
import cv2
19-
from interpretability.grad_cam import GradCAM
19+
from interpretability.grad_cam import GradCAM, GradCamPlusPlus
2020
from interpretability.guided_back_propagation import GuidedBackPropagation
2121

2222

@@ -150,6 +150,12 @@ def main(args):
150150
mask = grad_cam(inputs, args.class_id) # cam mask
151151
image_dict['cam'], image_dict['heatmap'] = gen_cam(img, mask)
152152
grad_cam.remove_handlers()
153+
# Grad-CAM++
154+
grad_cam_plus_plus = GradCamPlusPlus(net, layer_name)
155+
mask_plus_plus = grad_cam_plus_plus(inputs, args.class_id) # cam mask
156+
image_dict['cam++'], image_dict['heatmap++'] = gen_cam(img, mask_plus_plus)
157+
grad_cam_plus_plus.remove_handlers()
158+
153159
# GuidedBackPropagation
154160
gbp = GuidedBackPropagation(net)
155161
inputs.grad.zero_() # 梯度置零

0 commit comments

Comments
 (0)