Skip to content

Commit f702875

Browse files
author
Krzysztof Maciej Lis
committed
psp multihead
1 parent 1c17c23 commit f702875

File tree

3 files changed

+80
-2
lines changed

3 files changed

+80
-2
lines changed

.gitignore

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1-
__pycache__/
2-
models/seg_net_bayes.py
1+
2+
__pycache__/
3+
models/seg_net_bayes.py
4+

models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .fcn8s import *
55
from .gcn import *
66
from .psp_net import *
7+
from .psp_net_multihead import *
78
from .seg_net import *
89
from .seg_net_bayes import *
910
from .u_net import *

models/psp_net_multihead.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from torch import nn
4+
from torchvision import models
5+
6+
from ..utils import initialize_weights
7+
8+
from .psp_net import _PyramidPoolingModule
9+
10+
11+
class PSPHead(nn.Module):
12+
def __init__(self, num_classes):
13+
super().__init__()
14+
15+
self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
16+
self.final = nn.Sequential(
17+
nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
18+
nn.BatchNorm2d(512, momentum=.95),
19+
nn.ReLU(inplace=True),
20+
nn.Dropout(0.1),
21+
nn.Conv2d(512, num_classes, kernel_size=1)
22+
)
23+
24+
initialize_weights(self.ppm, self.final)
25+
26+
def forward(self, features_from_backbone, img_size):
27+
result = self.final(self.ppm(features_from_backbone))
28+
return F.interpolate(result, img_size[2:], mode='bilinear')
29+
30+
31+
class PSPNet_Multihead(nn.Module):
32+
def __init__(self, num_heads, num_classes, pretrained=True):
33+
super().__init__()
34+
35+
self.init_heads(num_heads, num_classes=num_classes)
36+
self.init_backbone(pretrained=pretrained)
37+
38+
def init_backbone(self, pretrained):
39+
resnet = models.resnet101(pretrained=pretrained)
40+
41+
for n, m in resnet.layer3.named_modules():
42+
if 'conv2' in n:
43+
m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
44+
elif 'downsample.0' in n:
45+
m.stride = (1, 1)
46+
for n, m in resnet.layer4.named_modules():
47+
if 'conv2' in n:
48+
m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
49+
elif 'downsample.0' in n:
50+
m.stride = (1, 1)
51+
52+
self.backbone = nn.Sequential(
53+
nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool), # layer 0
54+
resnet.layer1,
55+
resnet.layer2,
56+
resnet.layer3,
57+
resnet.layer4,
58+
)
59+
60+
def init_heads(self, num_heads, num_classes):
61+
62+
self.heads = nn.Sequential(
63+
*[PSPHead(num_classes=num_classes) for i in range(num_heads)]
64+
)
65+
66+
def forward(self, image):
67+
img_size = image.size()
68+
69+
backbone_features = self.backbone(image)
70+
71+
return torch.cat([
72+
head(backbone_features, img_size=img_size)
73+
for head in self.heads
74+
], dim=1)
75+

0 commit comments

Comments
 (0)