Commit 1c17c23: bayesian segnet
Author: Krzysztof Maciej Lis (committed)
1 parent 071a59d

6 files changed: +111 -8 lines

.gitignore
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+__pycache__/
+models/seg_net_bayes.py

models/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -5,4 +5,5 @@
 from .gcn import *
 from .psp_net import *
 from .seg_net import *
+from .seg_net_bayes import *
 from .u_net import *

models/psp_net.py
Lines changed: 7 additions & 6 deletions

@@ -9,7 +9,7 @@
 
 class _PyramidPoolingModule(nn.Module):
     def __init__(self, in_dim, reduction_dim, setting):
-        super(_PyramidPoolingModule, self).__init__()
+        super().__init__()
         self.features = []
         for s in setting:
             self.features.append(nn.Sequential(
@@ -24,14 +24,14 @@ def forward(self, x):
         x_size = x.size()
         out = [x]
         for f in self.features:
-            out.append(F.upsample(f(x), x_size[2:], mode='bilinear'))
+            out.append(F.interpolate(f(x), x_size[2:], mode='bilinear'))
         out = torch.cat(out, 1)
         return out
 
 
 class PSPNet(nn.Module):
     def __init__(self, num_classes, pretrained=True, use_aux=True):
-        super(PSPNet, self).__init__()
+        super().__init__()
         self.use_aux = use_aux
         resnet = models.resnet101(pretrained=pretrained)
         self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
@@ -63,7 +63,8 @@ def __init__(self, num_classes, pretrained=True, use_aux=True):
 
         initialize_weights(self.ppm, self.final)
 
-    def forward(self, x):
+    def forward(self, image):
+        x = image
         x_size = x.size()
         x = self.layer0(x)
         x = self.layer1(x)
@@ -75,8 +76,8 @@
         x = self.ppm(x)
         x = self.final(x)
         if self.training and self.use_aux:
-            return F.upsample(x, x_size[2:], mode='bilinear'), F.upsample(aux, x_size[2:], mode='bilinear')
-        return F.upsample(x, x_size[2:], mode='bilinear')
+            return F.interpolate(x, x_size[2:], mode='bilinear'), F.interpolate(aux, x_size[2:], mode='bilinear')
+        return F.interpolate(x, x_size[2:], mode='bilinear')
 
 
 # just a try, not recommend to use
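Context for the upsampling change: F.upsample was deprecated in favor of F.interpolate around PyTorch 0.4.1, which this commit tracks. A minimal sketch of the replacement call (the tensor shapes here are illustrative, not from this repo); note that newer PyTorch versions also warn for bilinear mode unless align_corners is passed explicitly:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)

# Same call shape as the diff above: resize a feature map to a target size.
# align_corners=False matches interpolate's default and silences the warning.
y = F.interpolate(x, size=(64, 64), mode='bilinear', align_corners=False)
print(y.shape)  # torch.Size([1, 3, 64, 64])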

models/seg_net.py
Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 class _DecoderBlock(nn.Module):
     def __init__(self, in_channels, out_channels, num_conv_layers):
         super(_DecoderBlock, self).__init__()
-        middle_channels = in_channels / 2
+        middle_channels = in_channels // 2
         layers = [
             nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2),
             nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
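Why this one-character change matters: under Python 3, / is true division and returns a float, which nn.Conv2d rejects as a channel count; // keeps the result an integer. A tiny illustration (the value is hypothetical):

in_channels = 512                   # hypothetical channel count
middle_channels = in_channels / 2   # 256.0: true division yields a float in Python 3
middle_channels = in_channels // 2  # 256: floor division keeps an int, as nn.Conv2d requires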

models/seg_net_bayes.py
Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+
+import torch
+from torch import nn
+from torchvision import models
+
+from ..utils import initialize_weights
+from .seg_net import _DecoderBlock, SegNet
+
+class SegNetBayes(SegNet):
+    def __init__(self, num_classes, dropout_p=0.5, pretrained=True, num_samples=16, min_batch_size=4):
+        """
+        @param num_samples: number of samples for the Monte-Carlo simulation,
+            how many times to run with random dropout
+        """
+        super().__init__(num_classes=num_classes, pretrained=pretrained)
+
+        self.drop = nn.Dropout2d(p=dropout_p, inplace=False)
+        self.num_samples = num_samples
+        self.min_batch_size = min_batch_size
+
+    def forward(self, x):
+        enc1 = self.enc1(x)
+        #print('enc1', enc1.shape)
+
+        enc2 = self.enc2(enc1)
+        #print('enc2', enc2.shape)
+
+        enc3 = self.enc3(enc2)
+        #print('enc3', enc3.shape)
+        enc3 = self.drop(enc3)
+        #print('enc3d', enc3.shape)
+
+        enc4 = self.enc4(enc3)
+        #print('enc4', enc4.shape)
+        enc4 = self.drop(enc4)
+        #print('enc4d', enc4.shape)
+
+        enc5 = self.enc5(enc4)
+        #print('enc5', enc5.shape)
+        enc5 = self.drop(enc5)
+        #print('enc5d', enc5.shape)
+
+        dec5 = self.dec5(enc5)
+        #print('dec5', dec5.shape)
+        dec5 = self.drop(dec5)
+        #print('dec5d', dec5.shape)
+
+        dec4 = self.dec4(torch.cat([enc4, dec5], 1))
+        #print('dec4', dec4.shape)
+        dec4 = self.drop(dec4)
+        #print('dec4d', dec4.shape)
+
+        dec3 = self.dec3(torch.cat([enc3, dec4], 1))
+        dec3 = self.drop(dec3)
+
+        dec2 = self.dec2(torch.cat([enc2, dec3], 1))
+        dec1 = self.dec1(torch.cat([enc1, dec2], 1))
+        return dec1
+
+    def forward_multisample(self, x, num_samples=None):
+        # dropout must be active
+        backup_train_mode = self.drop.training
+        self.drop.train()
+
+        softmax = torch.nn.Softmax2d()
+
+        num_samples = num_samples if num_samples else self.num_samples
+
+        results = [softmax(self.forward(x)).data.cpu() for i in range(num_samples)]
+
+        preds = torch.stack(results).cuda()
+        avg = torch.mean(preds, 0)
+        var = torch.var(preds, 0)
+        del preds
+
+        # restore mode
+        self.drop.train(backup_train_mode)
+
+        return dict(
+            mean = avg,
+            var = var,
+        )
+
+    #def sample(self, x, num_samples, batch_size):
+        #infer desired batch size from input shape
+        #we will divide num_samples into batches
+        #num_frames = x.shape[0]
+        #batch_size = max(num_frames, self.min_batch_size)
+
+        #for sample_idx in range(num_samples):
+            #pred =
+
+
+        #for fr_idx in range(num_frames):
+            #x_single = x[fr_idx:fr_idx+1, :, :, :]
+            #self.sample(x_single, num_samples, batch_size)
+
+
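The new class follows the Monte-Carlo dropout recipe popularized by Bayesian SegNet (Kendall et al., 2015): dropout stays active at test time, the network runs num_samples times, and the per-pixel mean and variance of the softmax outputs give a prediction plus an uncertainty estimate. A minimal usage sketch, assuming a CUDA device; the 19-class setup and tensor shapes are illustrative only:

import torch
from models import SegNetBayes

net = SegNetBayes(num_classes=19, dropout_p=0.5, num_samples=16).cuda()
net.eval()  # forward_multisample re-enables dropout internally, then restores the mode

image = torch.randn(1, 3, 256, 512).cuda()
with torch.no_grad():
    out = net.forward_multisample(image)

probs = out['mean']           # MC-averaged softmax, shape (1, 19, 256, 512)
uncertainty = out['var']      # per-class predictive variance across the samples
labels = probs.argmax(dim=1)  # hard per-pixel prediction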

utils/misc.py
Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ def initialize_weights(*models):
     for model in models:
         for module in model.modules():
             if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
-                nn.init.kaiming_normal(module.weight)
+                nn.init.kaiming_normal_(module.weight)
                 if module.bias is not None:
                     module.bias.data.zero_()
             elif isinstance(module, nn.BatchNorm2d):
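Context for the rename: the non-underscore initializers such as nn.init.kaiming_normal were deprecated in PyTorch 0.4; the trailing underscore marks the in-place variant that current PyTorch expects. A minimal sketch:

import torch
from torch import nn

w = torch.empty(64, 3, 3, 3)
nn.init.kaiming_normal_(w)  # fills w in place and returns the same tensor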
