Skip to content

Commit 8371645

Browse files
Completed Model
1 parent 3a37ee1 commit 8371645

File tree

3 files changed

+399
-22
lines changed

3 files changed

+399
-22
lines changed

RepNet/RepNetModel.py

+317-22
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,56 @@
1+
"""
2+
RepNetPeriodEstimator
3+
4+
Credit:
5+
https://openaccess.thecvf.com/content_CVPR_2020/papers/Dwibedi_Counting_Out_Time_Class_Agnostic_Video_Repetition_Counting_in_the_CVPR_2020_paper.pdf
6+
7+
Note:
8+
conv feature extractor is different from that used in original paper
9+
check dilation on temporal 3d conv
10+
pairwise_l2_distance transpose for b
11+
does get sims actually work
12+
input projection kernel_regularizer
13+
transformer_layers_config dff needed ???
14+
"""
15+
116
import torch
217
import torch.nn as nn
3-
from torchvision.models import resnet50
18+
import torch.nn.functional as F
19+
from torchvision.models import resnet50, wide_resnet50_2
20+
21+
from typing import Callable
422

523

624
class RepNetPeriodEstimator(nn.Module):
25+
"""
26+
RepNetPeriodEstimator
27+
"""
728

829
def __init__(self,
9-
num_frames = 64,
10-
image_size = 112,
11-
base_model_layer_name = 'conv4_block3_out',
12-
temperature = 13.544,
13-
dropout_rate = 0.25,
14-
l2_reg_weight = 1e-6,
15-
temporal_conv_channels = 512,
16-
temporal_conv_kernel_size = 3,
17-
temporal_conv_dilation_rate = 3,
18-
conv_channels = 32,
19-
conv_kernel_size = 3,
20-
transformer_layers_config = ((512, 4, 512),),
21-
transformer_dropout_rate = 0.0,
22-
transformer_reorder_ln = True,
23-
period_fc_channels = (512, 512),
24-
within_period_fc_channels = (512, 512)):
30+
num_frames: int = 64,
31+
image_size: int = 112,
32+
temperature: float = 13.544,
33+
dropout_rate: float = 0.25,
34+
temporal_conv_channels: int = 512,
35+
temporal_conv_kernel_size: int = 3,
36+
temporal_conv_dilation_rate: int = 3,
37+
conv_channels: int = 32,
38+
conv_kernel_size: int = 3,
39+
transformer_layers_config: tuple = ((512, 4, 512),),
40+
transformer_dropout_rate: float = 0.0,
41+
transformer_reorder_ln: bool = True,
42+
period_fc_channels: tuple = (512, 512),
43+
within_period_fc_channels: tuple = (512, 512)
44+
):
2545
super(RepNetPeriodEstimator, self).__init__()
2646

2747
# model parameters
2848
self.num_frames = num_frames
2949
self.image_size = image_size
3050

31-
self.base_model_layer_name = base_model_layer_name
32-
3351
self.temperature = temperature
3452

3553
self.dropout_rate = dropout_rate
36-
self.l2_reg_weight = l2_reg_weight
3754

3855
self.temporal_conv_channels = temporal_conv_channels
3956
self.temporal_conv_kernel_size = temporal_conv_kernel_size
@@ -49,5 +66,283 @@ def __init__(self,
4966
self.period_fc_channels = period_fc_channels
5067
self.within_period_fc_channels = within_period_fc_channels
5168

52-
# get resnet50 backbone
53-
69+
# get resnet50 backbone, drop layers down to Conv3 Bottleneck 2
70+
self.base_model = get_base_model(wide = False)
71+
72+
# this is a fix, dilation doesn't work like it does in tf
73+
self.temporal_conv_dilation_rate = 1
74+
75+
# temporal conv layers
76+
self.temporal_conv_layers = [nn.Conv3d(in_channels = 1024,
77+
out_channels = self.temporal_conv_channels,
78+
kernel_size = self.temporal_conv_kernel_size,
79+
padding = 1,
80+
dilation = (self.temporal_conv_dilation_rate, 1, 1)
81+
)]
82+
83+
self.temporal_bn_layers = [nn.BatchNorm3d(num_features = 512) for _ in self.temporal_conv_layers]
84+
85+
self.conv_3x3_layer = nn.Conv2d(in_channels = 1,
86+
out_channels = self.conv_channels,
87+
kernel_size = self.conv_kernel_size,
88+
padding = 1)
89+
90+
channels = self.transformer_layers_config[0][0]
91+
# how many in features
92+
self.input_projection = nn.Linear(in_features = 2048,
93+
out_features = channels,
94+
bias = True
95+
)
96+
97+
self.input_projection2 = nn.Linear(in_features = 2048,
98+
out_features = channels,
99+
bias = True
100+
)
101+
102+
length = self.num_frames
103+
self.pos_encoding = torch.empty(1, length, 1).normal_(mean = 0, std = 0.02)
104+
self.pos_encoding.requires_grad = True
105+
106+
self.pos_encoding2 = torch.empty(1, length, 1).normal_(mean = 0, std = 0.02)
107+
self.pos_encoding2.requires_grad = True
108+
109+
self.transformer_layers = []
110+
for d_model, num_heads, dff in self.transformer_layers_config:
111+
tfel = nn.TransformerEncoderLayer(d_model = d_model,
112+
nhead = num_heads,
113+
dim_feedforward = dff,
114+
dropout = self.transformer_dropout_rate)
115+
self.transformer_layers.append(tfel)
116+
117+
self.transformer_layers2 = []
118+
for d_model, num_heads, dff in self.transformer_layers_config:
119+
tfel = nn.TransformerEncoderLayer(d_model = d_model,
120+
nhead = num_heads,
121+
dim_feedforward = dff,
122+
dropout = self.transformer_dropout_rate)
123+
self.transformer_layers.append(tfel)
124+
125+
# period prediction module
126+
self.dropout_layer = nn.Dropout(self.dropout_rate)
127+
num_preds = self.num_frames // 2
128+
self.fc_layers = []
129+
130+
for channels in self.period_fc_channels:
131+
self.fc_layers.append(
132+
nn.Linear(in_features = channels,
133+
out_features = channels)
134+
135+
)
136+
self.fc_layers.append(nn.ReLU())
137+
138+
self.fc_layers.append(
139+
nn.Linear(in_features = self.period_fc_channels[0],
140+
out_features = num_preds)
141+
)
142+
143+
# Within Period Module
144+
num_preds = 1
145+
self.within_period_fc_layers = []
146+
for channels in self.within_period_fc_channels:
147+
self.within_period_fc_layers.append(
148+
nn.Linear(in_features = channels,
149+
out_features = channels)
150+
)
151+
self.fc_layers.append(nn.ReLU())
152+
self.within_period_fc_layers.append(
153+
nn.Linear(in_features = self.within_period_fc_channels[0],
154+
out_features = num_preds
155+
)
156+
)
157+
158+
def forward(self, x: torch.Tensor) -> tuple:
159+
"""
160+
:param x: input images
161+
:return: x, within_period_x, final_emb
162+
"""
163+
# Ensure usage of correct batch size
164+
batch_size = x.shape[0]
165+
x = torch.reshape(x, [-1, 3, self.image_size, self.image_size])
166+
# Conv feature extractor
167+
x = self.base_model(x)
168+
x = torch.reshape(x, [-1, 1024, 7, 7])
169+
c = x.shape[1]
170+
h = x.shape[2]
171+
w = x.shape[3]
172+
x = torch.reshape(x, [batch_size, c, -1, h, w])
173+
# x = torch.Size([20, 1024, 64, 7, 7])
174+
175+
for bn_layer, conv_layer in zip(self.temporal_bn_layers,
176+
self.temporal_conv_layers):
177+
x = conv_layer(x)
178+
x = bn_layer(x)
179+
F.relu(x)
180+
181+
# x = torch.Size([20, 512, 64, 7, 7])
182+
x, _ = torch.max(x, dim = 3)
183+
x, _ = torch.max(x, dim = 3)
184+
# x = torch.Size([20, 512, 64])
185+
final_embs = x.permute(0, 2, 1)
186+
187+
# get self smimillarity matrix
188+
x = get_sims(x, self.temperature)
189+
# x = torch.Size[20, 64, 64, 1]
190+
x = x.permute([0, 3, 1, 2])
191+
# x = torch.Size[20, 1, 64, 64]
192+
x = F.relu(self.conv_3x3_layer(x))
193+
# x = torch.Size[20, 32, 64, 64]
194+
x = torch.reshape(x, [batch_size, self.num_frames, -1])
195+
# x = torch.Size[20, 64, 2048]
196+
within_period_x = x
197+
198+
# Period prediction
199+
x = self.input_projection(x)
200+
x += self.pos_encoding
201+
# x = torch.Size[20, 64, 512]
202+
for transformer_layer in self.transformer_layers:
203+
x = transformer_layer(x)
204+
# x = torch.Size[20, 64, 512]
205+
# x = torch.Size[20, 64, 512]
206+
x = torch.reshape(x, [batch_size, self.num_frames, -1])
207+
208+
for fc_layer in self.fc_layers:
209+
x = self.dropout_layer(x)
210+
x = fc_layer(x)
211+
212+
# Within period prediction
213+
within_period_x = self.input_projection2(within_period_x)
214+
within_period_x += self.pos_encoding2
215+
216+
for transformer_layer in self.transformer_layers2:
217+
within_period_x = transformer_layer(within_period_x)
218+
within_period_x = torch.reshape(within_period_x, [batch_size, self.num_frames, -1])
219+
220+
for fc_layer in self.within_period_fc_layers:
221+
within_period_x = self.dropout_layer(within_period_x)
222+
within_period_x = fc_layer(within_period_x)
223+
224+
return x, within_period_x, final_embs
225+
226+
def preprocess(self, imgs: torch.Tensor):
227+
"""
228+
Preprocess input images
229+
:param imgs: images to preprocess
230+
:return: preprocessed images
231+
"""
232+
233+
imgs = imgs.float()
234+
imgs -= 127.5
235+
imgs /= 127.5
236+
imgs = F.interpolate(imgs, size = self.image_size)
237+
return imgs
238+
239+
240+
def get_base_model(wide: bool = True):
241+
"""
242+
Get backbone for RepNetEstimator
243+
:param wide: whether to use wide resent 50 or nor
244+
:return: Resnet base model for backbone
245+
"""
246+
247+
if wide:
248+
base_model = wide_resnet50_2(pretrained = False)
249+
else:
250+
base_model = resnet50(pretrained = False)
251+
base_model.fc = nn.Identity()
252+
base_model.avgpool = nn.Identity()
253+
base_model.layer4 = nn.Identity()
254+
base_model.layer3[3] = nn.Identity()
255+
base_model.layer3[4] = nn.Identity()
256+
base_model.layer3[5] = nn.Identity()
257+
return base_model
258+
259+
260+
def pairwise_l2_distance(a: torch.Tensor, b: torch.Tensor):
261+
"""
262+
Computes pairwise distances between all rows of a and all rows of b.
263+
:param a: tensor
264+
:param b: tensor
265+
:return pairwise distance
266+
"""
267+
norm_a = torch.sum(torch.square(a), dim = 0)
268+
norm_a = torch.reshape(norm_a, [-1, 1])
269+
norm_b = torch.sum(torch.square(b), dim = 0)
270+
norm_b = torch.reshape(norm_b, [1, -1])
271+
a = torch.transpose(a, 0, 1)
272+
zero_tensor = torch.zeros(64, 64)
273+
dist = torch.maximum(norm_a - 2.0 * torch.matmul(a, b) + norm_b, zero_tensor)
274+
return dist
275+
276+
277+
def get_sims(embs: torch.Tensor, temperature: float) -> torch.Tensor:
278+
"""
279+
Calculates self-similarity between batch of sequence of embeddings
280+
:param embs: embeddings
281+
:param temperature: temperature
282+
:return self similarity tensor
283+
"""
284+
285+
batch_size = embs.shape[0]
286+
seq_len = embs.shape[2]
287+
embs = torch.reshape(embs, [batch_size, -1, seq_len])
288+
289+
def _get_sims(embs: torch.Tensor):
290+
"""
291+
Calculates self-similarity between sequence of embeddings
292+
:param embs: embeddings
293+
"""
294+
295+
dist = pairwise_l2_distance(embs, embs)
296+
sims = -1.0 * dist
297+
return sims
298+
299+
sims = map_fn(_get_sims, embs)
300+
# sims = torch.Size[20, 64, 64]
301+
sims /= temperature
302+
sims = F.softmax(sims, dim = -1)
303+
sims = sims.unsqueeze(dim = -1)
304+
return sims
305+
306+
307+
def map_fn(fn: Callable, elems: torch.Tensor) -> torch.Tensor:
308+
"""
309+
Transforms elems by applying fn to each element unstacked on dim 0.
310+
:param fn: function to apply
311+
:param elems: tensor to transform
312+
:return: transformed tensor
313+
"""
314+
315+
sims_list = []
316+
for i in range(elems.shape[0]):
317+
sims_list.append(fn(elems[i]))
318+
sims = torch.stack(sims_list)
319+
return sims
320+
321+
322+
def test():
323+
"""
324+
Test for RepNetEstimator Model
325+
:return: nothing
326+
"""
327+
328+
model = RepNetPeriodEstimator()
329+
x = torch.randn(20, 64, 3, 112, 112)
330+
out = model(x)
331+
expected_shapes = [[20, 64, 32],
332+
[20, 64, 1],
333+
[20, 64, 512]]
334+
out_names = ["x", "within_period_x", "final_embs"]
335+
336+
for os, es, ons in zip(out, expected_shapes, out_names):
337+
for i in range(3):
338+
assert os.shape[0] == es[0], "Mismatch in shape for output {} at dim {}, expected {}, got {}".format(
339+
ons[i],
340+
i,
341+
es[0],
342+
os.shape[0])
343+
344+
print("Got all expected shapes, test passed")
345+
346+
347+
if __name__ == "__main__":
348+
test()

utils/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .get_counts import get_counts
2+
3+
__all__ = ["get_counts"]

0 commit comments

Comments
 (0)