-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathgrad_cam.py
45 lines (32 loc) · 1.3 KB
/
grad_cam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import cv2
import numpy as np
class GradCAM:
    """Grad-CAM heatmap generator for a PyTorch model.

    Hooks ``target_layer`` to capture its forward output (feature maps) and
    the gradient flowing back into that output, then combines them into a
    class-activation map for a chosen (or the top-scoring) class.

    Attributes:
        model: The wrapped model, switched to eval mode.
        featuremaps: Detached outputs of ``target_layer`` recorded since the
            start of the most recent ``__call__``.
        gradients: Detached gradients w.r.t. ``target_layer``'s output
            recorded since the start of the most recent ``__call__``.
    """

    def __init__(self, model, target_layer):
        """Wrap *model* and hook *target_layer*.

        Args:
            model: The network to explain; ``eval()`` is called on it.
            target_layer: Sub-module of *model* whose output is explained.
        """
        self.model = model.eval()
        self.featuremaps = []
        self.gradients = []
        # Prefer the "full" backward hook: the legacy register_backward_hook
        # is deprecated and may report gradients of the module's last
        # internal op instead of its output. Fall back only on old builds.
        register_backward = getattr(
            target_layer, "register_full_backward_hook", None
        )
        if register_backward is None:
            register_backward = target_layer.register_backward_hook
        self._handles = [
            target_layer.register_forward_hook(self.save_featuremaps),
            register_backward(self.save_gradients),
        ]

    def remove_hooks(self):
        """Unregister the hooks from the target layer (safe to call twice)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def save_featuremaps(self, module, input, output):
        """Forward hook: record the target layer's output.

        The tensor is detached so the stored copy does not keep the autograd
        graph (and all its intermediate activations) alive.
        """
        self.featuremaps.append(output.detach())

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: record the gradient w.r.t. the layer's output."""
        self.gradients.append(grad_output[0].detach())

    def get_cam_weights(self, grads):
        """Return one weight per channel: *grads* averaged over axes 1 and 2.

        ``__call__`` passes the first sample's gradient, indexed as
        (channel, height, width).
        """
        return np.mean(grads, axis=(1, 2))

    @staticmethod
    def _normalize(cam):
        """Shift *cam* so its minimum is 0 and scale its maximum to 1.

        A map whose range is zero is returned as all zeros instead of
        being divided by zero (which would yield NaN).
        """
        cam = cam - np.min(cam)
        peak = np.max(cam)
        return cam / peak if peak > 0 else cam

    def __call__(self, image, label=None):
        """Compute the Grad-CAM map for *image*.

        Args:
            image: Input tensor for ``model``; its last two dims (H, W) set
                the output map size. Only sample 0 is explained, and the
                ``.item()`` on the argmax requires a single-sample batch
                when *label* is None.
            label: Class index to explain; defaults to the model's top
                prediction.

        Returns:
            ``(label, cam)`` where *cam* is an (H, W) float32 array scaled
            to [0, 1], or all zeros when no location has positive evidence.
        """
        # Drop captures from earlier forward passes so memory stays bounded
        # and the [-1] lookups below refer to this call.
        self.featuremaps = []
        self.gradients = []

        preds = self.model(image)
        self.model.zero_grad()
        if label is None:
            label = preds.argmax(dim=1).item()
        preds[:, label].backward()

        # First sample of the batch: (channels, h, w) activations/gradients.
        featuremaps = self.featuremaps[-1].detach().cpu().numpy()[0]
        gradients = self.gradients[-1].detach().cpu().numpy()[0]

        weights = self.get_cam_weights(gradients)
        # Channel-weighted sum of the feature maps, followed by ReLU.
        cam = np.tensordot(weights, featuremaps, axes=1).astype(np.float32)
        cam = np.maximum(cam, 0)

        # cv2.resize takes dsize as (width, height).
        height, width = image.shape[-2:]
        cam = cv2.resize(cam, (int(width), int(height)))
        return label, self._normalize(cam)