+ """
+ RepNetPeriodEstimator
+
+ Credit:
+ https://openaccess.thecvf.com/content_CVPR_2020/papers/Dwibedi_Counting_Out_Time_Class_Agnostic_Video_Repetition_Counting_in_the_CVPR_2020_paper.pdf
+
+ Note:
+     - conv feature extractor differs from the one used in the original paper
+     - check dilation on the temporal 3d conv
+     - pairwise_l2_distance: check the transpose for b
+     - does get_sims actually work?
+     - input projection: kernel_regularizer
+     - is the dff entry in transformer_layers_config needed?
+ """
+
import torch
import torch.nn as nn
- from torchvision.models import resnet50
+ import torch.nn.functional as F
+ from torchvision.models import resnet50, wide_resnet50_2
+
+ from typing import Callable


class RepNetPeriodEstimator(nn.Module):
+     """
+     RepNet period estimator: predicts per-frame period-length logits and a
+     within-period (periodicity) score from a sequence of video frames.
+     """

    def __init__(self,
-                  num_frames=64,
-                  image_size=112,
-                  base_model_layer_name='conv4_block3_out',
-                  temperature=13.544,
-                  dropout_rate=0.25,
-                  l2_reg_weight=1e-6,
-                  temporal_conv_channels=512,
-                  temporal_conv_kernel_size=3,
-                  temporal_conv_dilation_rate=3,
-                  conv_channels=32,
-                  conv_kernel_size=3,
-                  transformer_layers_config=((512, 4, 512),),
-                  transformer_dropout_rate=0.0,
-                  transformer_reorder_ln=True,
-                  period_fc_channels=(512, 512),
-                  within_period_fc_channels=(512, 512)):
+                  num_frames: int = 64,
+                  image_size: int = 112,
+                  temperature: float = 13.544,
+                  dropout_rate: float = 0.25,
+                  temporal_conv_channels: int = 512,
+                  temporal_conv_kernel_size: int = 3,
+                  temporal_conv_dilation_rate: int = 3,
+                  conv_channels: int = 32,
+                  conv_kernel_size: int = 3,
+                  transformer_layers_config: tuple = ((512, 4, 512),),
+                  transformer_dropout_rate: float = 0.0,
+                  transformer_reorder_ln: bool = True,
+                  period_fc_channels: tuple = (512, 512),
+                  within_period_fc_channels: tuple = (512, 512)
+                  ):
        super(RepNetPeriodEstimator, self).__init__()

        # model parameters
        self.num_frames = num_frames
        self.image_size = image_size

-         self.base_model_layer_name = base_model_layer_name
-
        self.temperature = temperature

        self.dropout_rate = dropout_rate
-         self.l2_reg_weight = l2_reg_weight

        self.temporal_conv_channels = temporal_conv_channels
        self.temporal_conv_kernel_size = temporal_conv_kernel_size
@@ -49,5 +66,283 @@ def __init__(self,
        self.period_fc_channels = period_fc_channels
        self.within_period_fc_channels = within_period_fc_channels

-         # get resnet50 backbone
-
+         # get resnet50 backbone truncated after the third bottleneck of
+         # layer3 (the 'conv4_block3_out' feature map used in the paper)
+         self.base_model = get_base_model(wide=False)
+
+         # force a dilation rate of 1: PyTorch handles dilation/padding
+         # differently from TF's 'same' padding, so the configured rate is
+         # overridden here
+         self.temporal_conv_dilation_rate = 1
+
+         # temporal conv layers (applied to the stacked per-frame conv features);
+         # kept in a ModuleList so their parameters are registered
+         self.temporal_conv_layers = nn.ModuleList([
+             nn.Conv3d(in_channels=1024,
+                       out_channels=self.temporal_conv_channels,
+                       kernel_size=self.temporal_conv_kernel_size,
+                       padding=1,
+                       dilation=(self.temporal_conv_dilation_rate, 1, 1))
+         ])
+
+         self.temporal_bn_layers = nn.ModuleList(
+             [nn.BatchNorm3d(num_features=self.temporal_conv_channels)
+              for _ in self.temporal_conv_layers])
+
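+         # 3x3 conv applied to the (1, num_frames, num_frames) temporal
+         # self-similarity matrix, producing conv_channels maps per frame pair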
+         self.conv_3x3_layer = nn.Conv2d(in_channels=1,
+                                         out_channels=self.conv_channels,
+                                         kernel_size=self.conv_kernel_size,
+                                         padding=1)
+
+         channels = self.transformer_layers_config[0][0]
+         # in_features = num_frames * conv_channels (64 * 32 = 2048 by default)
+         self.input_projection = nn.Linear(in_features=self.num_frames * self.conv_channels,
+                                           out_features=channels,
+                                           bias=True)
+
+         self.input_projection2 = nn.Linear(in_features=self.num_frames * self.conv_channels,
+                                            out_features=channels,
+                                            bias=True)
+
+         # learned positional encodings (one scalar offset per frame position),
+         # registered as parameters so they are trained and moved with the model
+         length = self.num_frames
+         self.pos_encoding = nn.Parameter(
+             torch.empty(1, length, 1).normal_(mean=0, std=0.02))
+
+         self.pos_encoding2 = nn.Parameter(
+             torch.empty(1, length, 1).normal_(mean=0, std=0.02))
+
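+         # NOTE (assumption): nn.TransformerEncoderLayer defaults to
+         # (seq, batch, d_model) input ordering, while forward() passes
+         # (batch, seq, d_model); batch_first=True may be wanted on newer
+         # PyTorch versions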
+         self.transformer_layers = nn.ModuleList()
+         for d_model, num_heads, dff in self.transformer_layers_config:
+             tfel = nn.TransformerEncoderLayer(d_model=d_model,
+                                               nhead=num_heads,
+                                               dim_feedforward=dff,
+                                               dropout=self.transformer_dropout_rate)
+             self.transformer_layers.append(tfel)
+
+         self.transformer_layers2 = nn.ModuleList()
+         for d_model, num_heads, dff in self.transformer_layers_config:
+             tfel = nn.TransformerEncoderLayer(d_model=d_model,
+                                               nhead=num_heads,
+                                               dim_feedforward=dff,
+                                               dropout=self.transformer_dropout_rate)
+             self.transformer_layers2.append(tfel)
+
+         # period prediction module
+         self.dropout_layer = nn.Dropout(self.dropout_rate)
+         # one logit per candidate period length (num_frames // 2 classes)
+         num_preds = self.num_frames // 2
+         self.fc_layers = nn.ModuleList()
+
+         for channels in self.period_fc_channels:
+             self.fc_layers.append(
+                 nn.Linear(in_features=channels,
+                           out_features=channels)
+             )
+             self.fc_layers.append(nn.ReLU())
+
+         self.fc_layers.append(
+             nn.Linear(in_features=self.period_fc_channels[0],
+                       out_features=num_preds)
+         )
+
+         # Within Period Module (per-frame periodicity prediction)
+         num_preds = 1
+         self.within_period_fc_layers = nn.ModuleList()
+         for channels in self.within_period_fc_channels:
+             self.within_period_fc_layers.append(
+                 nn.Linear(in_features=channels,
+                           out_features=channels)
+             )
+             self.within_period_fc_layers.append(nn.ReLU())
+         self.within_period_fc_layers.append(
+             nn.Linear(in_features=self.within_period_fc_channels[0],
+                       out_features=num_preds)
+         )
+
+     def forward(self, x: torch.Tensor) -> tuple:
+         """
+         :param x: input frames, shape (batch, num_frames, 3, image_size, image_size)
+         :return: x (per-frame period-length logits), within_period_x
+             (per-frame periodicity logits), final_embs (per-frame embeddings)
+         """
+         # Ensure usage of correct batch size
+         batch_size = x.shape[0]
+         x = torch.reshape(x, [-1, 3, self.image_size, self.image_size])
+         # Conv feature extractor
+         x = self.base_model(x)
+         # restore the spatial layout flattened by the ResNet forward:
+         # features are 1024 x 7 x 7 for 112x112 inputs
+         x = torch.reshape(x, [-1, 1024, 7, 7])
+         c = x.shape[1]
+         h = x.shape[2]
+         w = x.shape[3]
+         x = torch.reshape(x, [batch_size, c, -1, h, w])
+         # x = torch.Size([20, 1024, 64, 7, 7])
+
+         for bn_layer, conv_layer in zip(self.temporal_bn_layers,
+                                         self.temporal_conv_layers):
+             x = conv_layer(x)
+             x = bn_layer(x)
+             x = F.relu(x)
+
+         # x = torch.Size([20, 512, 64, 7, 7])
+         # global max pool over the two spatial dims
+         x, _ = torch.max(x, dim=3)
+         x, _ = torch.max(x, dim=3)
+         # x = torch.Size([20, 512, 64])
+         final_embs = x.permute(0, 2, 1)
+
+         # get self-similarity matrix
+         x = get_sims(x, self.temperature)
+         # x = torch.Size([20, 64, 64, 1])
+         x = x.permute([0, 3, 1, 2])
+         # x = torch.Size([20, 1, 64, 64])
+         x = F.relu(self.conv_3x3_layer(x))
+         # x = torch.Size([20, 32, 64, 64])
+         # put frames first before flattening so each of the num_frames rows
+         # holds one frame's similarity features
+         x = x.permute(0, 2, 3, 1)
+         x = torch.reshape(x, [batch_size, self.num_frames, -1])
+         # x = torch.Size([20, 64, 2048])
+         within_period_x = x
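+         # both prediction heads consume these flattened self-similarity
+         # features; within_period_x keeps a copy before the period projection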
+
+         # Period prediction
+         x = self.input_projection(x)
+         x += self.pos_encoding
+         # x = torch.Size([20, 64, 512])
+         for transformer_layer in self.transformer_layers:
+             x = transformer_layer(x)
+         # x = torch.Size([20, 64, 512])
+         x = torch.reshape(x, [batch_size, self.num_frames, -1])
+
+         for fc_layer in self.fc_layers:
+             x = self.dropout_layer(x)
+             x = fc_layer(x)
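+         # x now holds per-frame period-length logits (num_frames // 2 classes)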
+
+         # Within period prediction
+         within_period_x = self.input_projection2(within_period_x)
+         within_period_x += self.pos_encoding2
+
+         for transformer_layer in self.transformer_layers2:
+             within_period_x = transformer_layer(within_period_x)
+         within_period_x = torch.reshape(within_period_x, [batch_size, self.num_frames, -1])
+
+         for fc_layer in self.within_period_fc_layers:
+             within_period_x = self.dropout_layer(within_period_x)
+             within_period_x = fc_layer(within_period_x)
+
+         return x, within_period_x, final_embs
+
+     def preprocess(self, imgs: torch.Tensor):
+         """
+         Scale uint8 images from [0, 255] to [-1, 1] and resize to image_size
+         :param imgs: images to preprocess
+         :return: preprocessed images
+         """
+
+         imgs = imgs.float()
+         imgs -= 127.5
+         imgs /= 127.5
+         imgs = F.interpolate(imgs, size=self.image_size)
+         return imgs
+
+
+ def get_base_model(wide: bool = True):
+     """
+     Get the backbone for RepNetPeriodEstimator
+     :param wide: whether to use Wide ResNet-50-2 or the standard ResNet-50
+     :return: truncated ResNet base model for the backbone
+     """
+
+     if wide:
+         base_model = wide_resnet50_2(pretrained=False)
+     else:
+         base_model = resnet50(pretrained=False)
+     # truncate: keep layer3 up to its third bottleneck ('conv4_block3_out' in
+     # the paper's notation); later blocks and the classifier become no-ops
+     base_model.fc = nn.Identity()
+     base_model.avgpool = nn.Identity()
+     base_model.layer4 = nn.Identity()
+     base_model.layer3[3] = nn.Identity()
+     base_model.layer3[4] = nn.Identity()
+     base_model.layer3[5] = nn.Identity()
+     return base_model
+
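+ # Note: with avgpool and fc replaced by Identity, the torchvision ResNet
+ # forward still flattens its output, so the truncated backbone returns a
+ # (N, 1024 * 7 * 7) tensor for 112x112 inputs; forward() reshapes it back.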
+
+ def pairwise_l2_distance(a: torch.Tensor, b: torch.Tensor):
+     """
+     Computes pairwise squared L2 distances between the columns of a and the
+     columns of b (each column is one frame embedding).
+     :param a: tensor of shape (channels, num_frames)
+     :param b: tensor of shape (channels, num_frames)
+     :return: pairwise distance matrix of shape (num_frames, num_frames)
+     """
+     norm_a = torch.sum(torch.square(a), dim=0)
+     norm_a = torch.reshape(norm_a, [-1, 1])
+     norm_b = torch.sum(torch.square(b), dim=0)
+     norm_b = torch.reshape(norm_b, [1, -1])
+     a = torch.transpose(a, 0, 1)
+     # clamp at zero instead of comparing against a hard-coded 64x64 zero
+     # tensor (which would also sit on the wrong device for GPU inputs)
+     dist = torch.clamp(norm_a - 2.0 * torch.matmul(a, b) + norm_b, min=0.0)
+     return dist
+
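+ # pairwise_l2_distance relies on the expansion
+ # ||a_i - b_j||^2 = ||a_i||^2 - 2 * a_i . b_j + ||b_j||^2
+ # over column embeddings, which is why only the norms and one matmul are needed.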
+
+ def get_sims(embs: torch.Tensor, temperature: float) -> torch.Tensor:
+     """
+     Calculates self-similarity within a batch of embedding sequences
+     :param embs: embeddings of shape (batch, channels, seq_len)
+     :param temperature: softmax temperature
+     :return: self-similarity tensor of shape (batch, seq_len, seq_len, 1)
+     """
+
+     batch_size = embs.shape[0]
+     seq_len = embs.shape[2]
+     embs = torch.reshape(embs, [batch_size, -1, seq_len])
+
+     def _get_sims(embs: torch.Tensor):
+         """
+         Calculates self-similarity within one sequence of embeddings
+         :param embs: embeddings of shape (channels, seq_len)
+         """
+
+         dist = pairwise_l2_distance(embs, embs)
+         sims = -1.0 * dist
+         return sims
+
+     sims = map_fn(_get_sims, embs)
+     # sims = torch.Size([20, 64, 64])
+     sims /= temperature
+     sims = F.softmax(sims, dim=-1)
+     sims = sims.unsqueeze(dim=-1)
+     return sims
+
+
+ def map_fn(fn: Callable, elems: torch.Tensor) -> torch.Tensor:
+     """
+     Transforms elems by applying fn to each element unstacked on dim 0.
+     :param fn: function to apply
+     :param elems: tensor to transform
+     :return: transformed tensor
+     """
+
+     sims_list = []
+     for i in range(elems.shape[0]):
+         sims_list.append(fn(elems[i]))
+     sims = torch.stack(sims_list)
+     return sims
+
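+ # Note (assumption): the explicit loop mirrors tf.map_fn from the original code;
+ # torch.cdist(embs.transpose(1, 2), embs.transpose(1, 2)) would give the same
+ # pairwise distances (unsquared) for the whole batch in one call.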
+
+ def test():
+     """
+     Smoke test for the RepNetPeriodEstimator model
+     :return: nothing
+     """
+
+     model = RepNetPeriodEstimator()
+     x = torch.randn(20, 64, 3, 112, 112)
+     out = model(x)
+     expected_shapes = [[20, 64, 32],
+                        [20, 64, 1],
+                        [20, 64, 512]]
+     out_names = ["x", "within_period_x", "final_embs"]
+
+     for out_tensor, expected_shape, name in zip(out, expected_shapes, out_names):
+         for i in range(3):
+             assert out_tensor.shape[i] == expected_shape[i], \
+                 "Mismatch in shape for output {} at dim {}, expected {}, got {}".format(
+                     name, i, expected_shape[i], out_tensor.shape[i])
+
+     print("Got all expected shapes, test passed")
+
+
+ if __name__ == "__main__":
+     test()