-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathgradcam.py
122 lines (98 loc) · 4.42 KB
/
gradcam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright 2020 Samson Woof
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras import Model
def grad_cam(model, img,
             layer_name="block5_conv3", label_name=None,
             category_id=None):
    """Compute a Grad-CAM heatmap for a single image.

    Args:
        model: A model object, built with tf.keras 2.X.
        img: An image ndarray; a leading batch axis is added before it is
            fed to the model.
        layer_name: A string, name of the layer in `model` whose output
            is used to build the heatmap.
        label_name: A list or None. When given, it should be the list of
            all label names; the entry at `category_id` is printed.
        category_id: An integer, index of the class.
            Defaults to the category with the highest predicted score.

    Return:
        A heatmap ndarray (without color), clipped at zero, divided by
        its maximum value and squeezed of size-1 axes.
    """
    batch = np.expand_dims(img, axis=0)

    # Auxiliary model exposing both the chosen layer's output and the
    # final prediction in one forward pass.
    target_layer = model.get_layer(layer_name)
    probe = Model([model.inputs], [target_layer.output, model.output])

    with tf.GradientTape() as tape:
        feature_maps, scores = probe(batch)
        if category_id is None:
            category_id = np.argmax(scores[0])
        if label_name is not None:
            print(label_name[category_id])
        class_score = scores[:, category_id]

    # Only a first-order gradient is needed, so it is taken once the
    # forward pass has been recorded.
    gradients = tape.gradient(class_score, feature_maps)
    channel_weights = tf.reduce_mean(gradients, axis=(0, 1, 2))

    # Weight every channel by its averaged gradient, then average the
    # weighted maps over the channel axis.
    weighted_maps = tf.multiply(channel_weights, feature_maps)
    cam = np.maximum(tf.reduce_mean(weighted_maps, axis=-1), 0)

    # Scale by the peak value; a tiny divisor stands in for an all-zero map.
    peak = np.max(cam)
    cam /= 1e-10 if peak == 0 else peak

    return np.squeeze(cam)
def grad_cam_plus(model, img,
                  layer_name="block5_conv3", label_name=None,
                  category_id=None):
    """Get a heatmap by Grad-CAM++.

    Args:
        model: A model object, built with tf.keras 2.X.
        img: An image ndarray; a leading batch axis is added before it is
            fed to the model.
        layer_name: A string, name of the layer in `model` whose output
            is used to build the heatmap.
        label_name: A list or None. When given, it should be the list of
            all label names; the entry at `category_id` is printed.
        category_id: An integer, index of the class.
            Default is the category with the highest score in the prediction.

    Return:
        A heatmap ndarray (without color), clipped at zero and divided by
        its maximum value.
    """
    img_tensor = np.expand_dims(img, axis=0)

    # Auxiliary model exposing both the chosen layer's output and the
    # final prediction in one forward pass.
    conv_layer = model.get_layer(layer_name)
    heatmap_model = Model([model.inputs], [conv_layer.output, model.output])

    # Three nested tapes: each gradient is computed while the enclosing
    # tape(s) are still recording, so the result can be differentiated
    # again to obtain the second and third derivatives.
    with tf.GradientTape() as gtape1:
        with tf.GradientTape() as gtape2:
            with tf.GradientTape() as gtape3:
                conv_output, predictions = heatmap_model(img_tensor)
                if category_id is None:
                    category_id = np.argmax(predictions[0])
                if label_name is not None:
                    print(label_name[category_id])
                output = predictions[:, category_id]
                conv_first_grad = gtape3.gradient(output, conv_output)
            conv_second_grad = gtape2.gradient(conv_first_grad, conv_output)
        conv_third_grad = gtape1.gradient(conv_second_grad, conv_output)

    # Per-channel sum of the layer output over the batch and spatial axes.
    global_sum = np.sum(conv_output, axis=(0, 1, 2))

    # alpha = d2 / (2 * d2 + d3 * sum(A)); zero denominators are replaced
    # by a tiny value so the division stays finite.
    alpha_num = conv_second_grad[0]
    alpha_denom = conv_second_grad[0]*2.0 + conv_third_grad[0]*global_sum
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, 1e-10)

    alphas = alpha_num/alpha_denom

    # Normalise the alphas per channel. A channel whose alphas sum to zero
    # would otherwise give 0/0 = NaN, and the channel-wise sum below would
    # spread that NaN over the whole heatmap; such channels are divided by
    # 1 instead. `ones_like` keeps the dtype of the constant unchanged.
    alpha_normalization_constant = np.sum(alphas, axis=(0, 1))
    alpha_normalization_constant = np.where(
        alpha_normalization_constant != 0.0,
        alpha_normalization_constant,
        np.ones_like(alpha_normalization_constant))
    alphas /= alpha_normalization_constant

    # Only positive first-order gradients contribute to the weights.
    weights = np.maximum(conv_first_grad[0], 0.0)

    deep_linearization_weights = np.sum(weights*alphas, axis=(0, 1))
    grad_cam_map = np.sum(deep_linearization_weights*conv_output[0], axis=2)

    # Clip at zero and scale by the peak value; a tiny divisor stands in
    # for an all-zero map.
    heatmap = np.maximum(grad_cam_map, 0)
    max_heat = np.max(heatmap)
    if max_heat == 0:
        max_heat = 1e-10
    heatmap /= max_heat

    return heatmap