@author: mick.yi
"""
import torch
from torch import nn
from torch.autograd import Function
import numpy as np


class _GuidedBackPropagationReLUFunction(Function):
    """
    Custom autograd Function implementing the forward and backward
    passes of the guided-backprop ReLU.
    """

    @staticmethod
    def forward(ctx, x):
        output = torch.clamp(x, min=0)
        ctx.save_for_backward(x, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, _ = ctx.saved_tensors
        # Guided backprop: let the gradient through only where the forward
        # input was positive AND the incoming gradient is positive.
        forward_mask = (x > 0).type_as(x)
        backward_mask = (grad_output > 0.).type_as(grad_output)
        grad_input = grad_output * forward_mask * backward_mask
        return grad_input


class GuidedBackPropagationReLU(nn.Module):
    def __init__(self, **kwargs):
        super(GuidedBackPropagationReLU, self).__init__(**kwargs)

    def forward(self, x):
        # apply() registers the custom backward with autograd; calling
        # forward() directly would bypass the guided-backprop backward.
        return _GuidedBackPropagationReLUFunction.apply(x)
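
# A quick check of the masking rule (illustrative addition, not in the
# original file): both masks must be positive for a gradient to pass.
#   x = torch.tensor([[-1.0, 2.0]], requires_grad=True)
#   y = GuidedBackPropagationReLU()(x)
#   y.backward(torch.tensor([[1.0, -1.0]]))
#   x.grad  # -> tensor([[0., 0.]]): negative input and negative grad both masked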


def replace_relu(m):
    """
    Replace every nn.ReLU inside m with GuidedBackPropagationReLU.
    :param m: module
    :return:
    """
    if len(m._modules) == 0:
        return

    for name, module in m._modules.items():
        if len(module._modules) > 0:
            replace_relu(module)  # recurse into container modules
        elif isinstance(module, nn.ReLU):  # module is a leaf layer
            if isinstance(m, nn.Sequential):
                m[int(name)] = GuidedBackPropagationReLU()
            elif isinstance(m, nn.ModuleList):
                m[int(name)] = GuidedBackPropagationReLU()
            elif isinstance(m, nn.ModuleDict):
                m[name] = GuidedBackPropagationReLU()
            elif hasattr(m, name):
                setattr(m, name, GuidedBackPropagationReLU())


class GuidedBackPropagation(object):
    def __init__(self, net):
        self.net = net
        self.net.eval()

    def __call__(self, inputs, index=None):
        """
        :param inputs: [1,3,H,W] tensor; must have requires_grad=True so
                       that backward() populates inputs.grad
        :param index: class_id; defaults to the top-scoring class
        :return: gradient of the class score w.r.t. inputs
        """
        output = self.net(inputs)  # [1,num_classes]
        if index is None:
            index = np.argmax(output.cpu().data.numpy())
        target = output[0][index]
        target.backward()

        return inputs.grad
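

if __name__ == '__main__':
    # Usage sketch (illustrative addition, not part of the original commit).
    # Assumes torchvision is installed; any CNN built from nn.ReLU layers
    # and standard containers works the same way.
    import torchvision.models as models

    net = models.vgg16()  # load pretrained weights for real visualizations
    replace_relu(net)  # swap every nn.ReLU for GuidedBackPropagationReLU

    gbp = GuidedBackPropagation(net)
    # inputs must require grad, otherwise inputs.grad stays None.
    inputs = torch.randn(1, 3, 224, 224, requires_grad=True)
    grad = gbp(inputs)  # [1,3,224,224] guided gradients w.r.t. the input
    print(grad.shape)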