Skip to content

Commit 025babd

Browse files
committed
add GuidedBackPropagation
1 parent f69da98 commit 025babd

File tree

1 file changed

+75
-1
lines changed

1 file changed

+75
-1
lines changed

interpretability/guided_back_propagation.py

+75-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,78 @@
44
55
@author: mick.yi
66
7-
"""
7+
"""
8+
import torch
9+
from torch import nn
10+
from torch.autograd import Function
11+
import numpy as np
12+
13+
14+
class _GuidedBackPropagationReLUFunction(Function):
15+
"""
16+
定义Function实现自定义的forward和backward
17+
"""
18+
19+
def forward(self, x):
20+
output = torch.clamp(x, min=0)
21+
self.save_for_backward(x, output)
22+
return output
23+
24+
def backward(self, grad_output):
25+
x, _ = self.saved_tensors
26+
forward_mask = (x > 0).type_as(x)
27+
backward_mask = (grad_output > 0.).type_as(grad_output)
28+
grad_input = grad_output * forward_mask * backward_mask
29+
return grad_input
30+
31+
32+
class GuidedBackPropagationReLU(nn.Module):
33+
def __init__(self, **kwargs):
34+
super(GuidedBackPropagationReLU, self).__init__(**kwargs)
35+
36+
def forward(self, x):
37+
return _GuidedBackPropagationReLUFunction().forward(x)
38+
39+
40+
def replace_relu(m):
41+
"""
42+
替换m中所有的ReLU为GuidedBackPropagationReLU
43+
:param m: module
44+
:return:
45+
"""
46+
if len(m._modules) == 0:
47+
return
48+
49+
for name, module in m._modules.items():
50+
if len(module._modules) > 0:
51+
replace_relu(module)
52+
elif isinstance(module, nn.ReLU): # module是最基础的layer
53+
if isinstance(m, nn.Sequential):
54+
m[int(name)] = GuidedBackPropagationReLU()
55+
elif isinstance(m, nn.ModuleList):
56+
m[int(name)] = GuidedBackPropagationReLU()
57+
elif isinstance(m, nn.ModuleDict):
58+
m[name] = GuidedBackPropagationReLU()
59+
elif hasattr(m, name):
60+
setattr(m, name, GuidedBackPropagationReLU())
61+
62+
63+
class GuidedBackPropagation(object):
64+
def __init__(self, net):
65+
self.net = net
66+
self.net.eval()
67+
68+
def __call__(self, inputs, index=None):
69+
"""
70+
71+
:param inputs: [1,3,H,W]
72+
:param index: class_id
73+
:return:
74+
"""
75+
output = self.net(inputs) # [1,num_classes]
76+
if index is None:
77+
index = np.argmax(output.cpu().data.numpy())
78+
target = output[0][index]
79+
target.backward()
80+
81+
return inputs.grad

0 commit comments

Comments
 (0)