1
+ import torch
2
+ import torch .nn as nn
3
+ import torch .utils .model_zoo as model_zoo
4
+ import torch .nn .functional as F
5
+
6
+
7
+ class PPNet (nn .Module ):
8
+
9
+ def __init__ (self , features , img_size , prototype_shape ,
10
+ num_classes , init_weights = True ,
11
+ prototype_activation_function = 'log' ,
12
+ add_on_layers_type = 'bottleneck' ):
13
+
14
+ super (PPNet , self ).__init__ ()
15
+ self .img_size = img_size
16
+ self .prototype_shape = prototype_shape
17
+ self .num_prototypes = prototype_shape [0 ]
18
+ self .num_classes = num_classes
19
+ self .epsilon = 1e-4
20
+
21
+ # prototype_activation_function could be 'log', 'linear',
22
+ # or a generic function that converts distance to similarity score
23
+ self .prototype_activation_function = prototype_activation_function
24
+
25
+ '''
26
+ Here we are initializing the class identities of the prototypes
27
+ Without domain specific knowledge we allocate the same number of
28
+ prototypes for each class
29
+ '''
30
+ assert (self .num_prototypes % self .num_classes == 0 )
31
+ # a onehot indication matrix for each prototype's class identity
32
+ self .prototype_class_identity = torch .zeros (self .num_prototypes ,
33
+ self .num_classes )
34
+
35
+ num_prototypes_per_class = self .num_prototypes // self .num_classes
36
+ for j in range (self .num_prototypes ):
37
+ self .prototype_class_identity [j , j // num_prototypes_per_class ] = 1
38
+
39
+ # this has to be named features to allow the precise loading
40
+ self .features = features
41
+
42
+ if add_on_layers_type == 'bottleneck' :
43
+ add_on_layers = []
44
+ current_in_channels = 512
45
+ while (current_in_channels > self .prototype_shape [1 ]) or (len (add_on_layers ) == 0 ):
46
+ current_out_channels = max (self .prototype_shape [1 ], (current_in_channels // 2 ))
47
+ add_on_layers .append (nn .Conv2d (in_channels = current_in_channels ,
48
+ out_channels = current_out_channels ,
49
+ kernel_size = 1 ))
50
+ add_on_layers .append (nn .ReLU ())
51
+ add_on_layers .append (nn .Conv2d (in_channels = current_out_channels ,
52
+ out_channels = current_out_channels ,
53
+ kernel_size = 1 ))
54
+ if current_out_channels > self .prototype_shape [1 ]:
55
+ add_on_layers .append (nn .ReLU ())
56
+ else :
57
+ assert (current_out_channels == self .prototype_shape [1 ])
58
+ add_on_layers .append (nn .Sigmoid ())
59
+ current_in_channels = current_in_channels // 2
60
+ self .add_on_layers = nn .Sequential (* add_on_layers )
61
+ else :
62
+ self .add_on_layers = nn .Sequential (
63
+ nn .Conv2d (in_channels = 512 , out_channels = self .prototype_shape [1 ],
64
+ kernel_size = 1 ),
65
+ nn .ReLU (),
66
+ nn .Conv2d (in_channels = self .prototype_shape [1 ], out_channels = self .prototype_shape [1 ], kernel_size = 1 ),
67
+ nn .Sigmoid ()
68
+ )
69
+
70
+ self .prototype_vectors = nn .Parameter (torch .rand (self .prototype_shape ),
71
+ requires_grad = True )
72
+
73
+ # do not make this just a tensor,
74
+ # since it will not be moved automatically to gpu
75
+ self .ones = nn .Parameter (torch .ones (self .prototype_shape ),
76
+ requires_grad = False )
77
+
78
+ self .last_layer = nn .Linear (self .num_prototypes , self .num_classes ,
79
+ bias = False ) # do not use bias
80
+
81
+ if init_weights :
82
+ self ._initialize_weights ()
83
+
84
+ def conv_features (self , x ):
85
+ '''
86
+ the feature input to prototype layer
87
+ '''
88
+ x = self .features (x )
89
+ x = self .add_on_layers (x )
90
+ return x
91
+
92
+ @staticmethod
93
+ def _weighted_l2_convolution (input , filter , weights ):
94
+ '''
95
+ input of shape N * c * h * w
96
+ filter of shape P * c * h1 * w1
97
+ weight of shape P * c * h1 * w1
98
+ '''
99
+ input2 = input ** 2
100
+ input_patch_weighted_norm2 = F .conv2d (input = input2 , weight = weights )
101
+
102
+ filter2 = filter ** 2
103
+ weighted_filter2 = filter2 * weights
104
+ filter_weighted_norm2 = torch .sum (weighted_filter2 , dim = (1 , 2 , 3 ))
105
+ filter_weighted_norm2_reshape = filter_weighted_norm2 .view (- 1 , 1 , 1 )
106
+
107
+ weighted_filter = filter * weights
108
+ weighted_inner_product = F .conv2d (input = input , weight = weighted_filter )
109
+
110
+ # use broadcast
111
+ intermediate_result = \
112
+ - 2 * weighted_inner_product + filter_weighted_norm2_reshape
113
+ # x2_patch_sum and intermediate_result are of the same shape
114
+ distances = F .relu (input_patch_weighted_norm2 + intermediate_result )
115
+
116
+ return distances
117
+
118
+ def _l2_convolution (self , x ):
119
+ '''
120
+ apply self.prototype_vectors as l2-convolution filters on input x
121
+ '''
122
+ x2 = x ** 2
123
+ x2_patch_sum = F .conv2d (input = x2 , weight = self .ones )
124
+
125
+ p2 = self .prototype_vectors ** 2
126
+ p2 = torch .sum (p2 , dim = (1 , 2 , 3 ))
127
+ # p2 is a vector of shape (num_prototypes,)
128
+ # then we reshape it to (num_prototypes, 1, 1)
129
+ p2_reshape = p2 .view (- 1 , 1 , 1 )
130
+
131
+ xp = F .conv2d (input = x , weight = self .prototype_vectors )
132
+ intermediate_result = - 2 * xp + p2_reshape # use broadcast
133
+ # x2_patch_sum and intermediate_result are of the same shape
134
+ distances = F .relu (x2_patch_sum + intermediate_result )
135
+
136
+ return distances
137
+
138
+ def prototype_distances (self , x ):
139
+ '''
140
+ x is the raw input
141
+ '''
142
+ conv_features = self .conv_features (x )
143
+ distances = self ._l2_convolution (conv_features )
144
+ return distances
145
+
146
+ def distance_2_similarity (self , distances ):
147
+ if self .prototype_activation_function == 'log' :
148
+ return torch .log ((distances + 1 ) / (distances + self .epsilon ))
149
+ elif self .prototype_activation_function == 'linear' :
150
+ return - distances
151
+ else :
152
+ return self .prototype_activation_function (distances )
153
+
154
+ def forward (self , x ):
155
+ distances = self .prototype_distances (x )
156
+ '''
157
+ we cannot refactor the lines below for similarity scores
158
+ because we need to return min_distances
159
+ '''
160
+ # global min pooling
161
+ min_distances = - F .max_pool2d (- distances ,
162
+ kernel_size = (distances .size ()[2 ],
163
+ distances .size ()[3 ]))
164
+ min_distances = min_distances .view (- 1 , self .num_prototypes )
165
+ prototype_activations = self .distance_2_similarity (min_distances )
166
+ logits = self .last_layer (prototype_activations )
167
+ return logits , min_distances
168
+
169
+ def push_forward (self , x ):
170
+ '''this method is needed for the pushing operation'''
171
+ conv_output = self .conv_features (x )
172
+ distances = self ._l2_convolution (conv_output )
173
+ return conv_output , distances
174
+
175
+ def prune_prototypes (self , prototypes_to_prune ):
176
+ '''
177
+ prototypes_to_prune: a list of indices each in
178
+ [0, current number of prototypes - 1] that indicates the prototypes to
179
+ be removed
180
+ '''
181
+ prototypes_to_keep = list (set (range (self .num_prototypes )) - set (prototypes_to_prune ))
182
+
183
+ self .prototype_vectors = nn .Parameter (self .prototype_vectors .data [prototypes_to_keep , ...],
184
+ requires_grad = True )
185
+
186
+ self .prototype_shape = list (self .prototype_vectors .size ())
187
+ self .num_prototypes = self .prototype_shape [0 ]
188
+
189
+ # changing self.last_layer in place
190
+ # changing in_features and out_features make sure the numbers are consistent
191
+ self .last_layer .in_features = self .num_prototypes
192
+ self .last_layer .out_features = self .num_classes
193
+ self .last_layer .weight .data = self .last_layer .weight .data [:, prototypes_to_keep ]
194
+
195
+ # self.ones is nn.Parameter
196
+ self .ones = nn .Parameter (self .ones .data [prototypes_to_keep , ...],
197
+ requires_grad = False )
198
+ # self.prototype_class_identity is torch tensor
199
+ # so it does not need .data access for value update
200
+ self .prototype_class_identity = self .prototype_class_identity [prototypes_to_keep , :]
201
+
202
+ def set_last_layer_incorrect_connection (self , incorrect_strength ):
203
+ '''
204
+ the incorrect strength will be actual strength if -0.5 then input -0.5
205
+ '''
206
+ positive_one_weights_locations = torch .t (self .prototype_class_identity )
207
+ negative_one_weights_locations = 1 - positive_one_weights_locations
208
+
209
+ correct_class_connection = 1
210
+ incorrect_class_connection = incorrect_strength
211
+ self .last_layer .weight .data .copy_ (
212
+ correct_class_connection * positive_one_weights_locations
213
+ + incorrect_class_connection * negative_one_weights_locations )
214
+
215
+ def _initialize_weights (self ):
216
+ for m in self .add_on_layers .modules ():
217
+ if isinstance (m , nn .Conv2d ):
218
+ # every init technique has an underscore _ in the name
219
+ nn .init .kaiming_normal_ (m .weight , mode = 'fan_out' , nonlinearity = 'relu' )
220
+
221
+ if m .bias is not None :
222
+ nn .init .constant_ (m .bias , 0 )
223
+
224
+ elif isinstance (m , nn .BatchNorm2d ):
225
+ nn .init .constant_ (m .weight , 1 )
226
+ nn .init .constant_ (m .bias , 0 )
227
+
228
+ self .set_last_layer_incorrect_connection (incorrect_strength = - 0.5 )
229
+
230
+
231
def construct_PPNet(bone, img_size=224,
                    prototype_shape=(1500, 512, 1, 1), num_classes=15,
                    prototype_activation_function='log',
                    add_on_layers_type='bottleneck'):
    '''
    build a PPNet that uses `bone` as its feature extractor

    all remaining arguments are forwarded unchanged to the PPNet
    constructor; weight initialization is always enabled
    '''
    ppnet_arguments = dict(
        features=bone,
        img_size=img_size,
        prototype_shape=prototype_shape,
        num_classes=num_classes,
        init_weights=True,
        prototype_activation_function=prototype_activation_function,
        add_on_layers_type=add_on_layers_type,
    )
    return PPNet(**ppnet_arguments)
0 commit comments