-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmodel.py
48 lines (40 loc) · 2 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#%%
import yaml, math, os
# Load the run configuration once, at import time. The relative path is
# resolved against the current working directory of the process.
# An explicit encoding is used so that decoding does not depend on the
# platform's locale default.
# NOTE(review): FullLoader is acceptable for a trusted local file; prefer
# yaml.safe_load if config.yaml can ever come from an untrusted source.
with open('config.yaml', encoding='utf-8') as fh:
    config = yaml.load(fh, Loader=yaml.FullLoader)
import torch
import torch.nn.functional as F
import torch.nn as nn
from backbone import MSCANet
from decoder import HamDecoder
class SegNext(nn.Module):
    """Segmentation model: ``MSCANet`` encoder -> ``HamDecoder`` -> 1x1 classifier.

    The classifier output is resized to the spatial size of the input, so
    ``forward`` returns one score map per class at input resolution.
    """

    def __init__(self, num_classes, in_channnels=3, embed_dims=[32, 64, 460, 256],
                 ffn_ratios=[4, 4, 4, 4], depths=[3, 3, 5, 2], num_stages=4,
                 dec_outChannels=256, config=config, dropout=0.0, drop_path=0.0):
        """Build encoder, decoder and classification head, then initialise weights.

        Args:
            num_classes: Number of output channels of the final 1x1 convolution.
            in_channnels: Forwarded to ``MSCANet`` (the spelling is part of the
                public keyword interface and is kept as is).
            embed_dims: Forwarded to ``MSCANet`` and, as ``enc_embed_dims``, to
                ``HamDecoder``.
            ffn_ratios: Forwarded to ``MSCANet``.
            depths: Forwarded to ``MSCANet``.
            num_stages: Forwarded to ``MSCANet``.
            dec_outChannels: ``outChannels`` of ``HamDecoder`` and input channel
                count of the classification convolution.
            config: Forwarded to ``HamDecoder``; defaults to the module-level
                configuration captured when the class was defined.
            dropout: Forwarded to ``MSCANet``.
            drop_path: Forwarded to ``MSCANet``.
        """
        super().__init__()

        # NOTE(review): the default ``embed_dims[2] == 460`` looks like a typo
        # for 160 -- confirm. It is left unchanged here because altering it
        # would change parameter shapes for callers relying on the default.

        # The list defaults are shared between all calls. Copy them once per
        # instance so a sub-module that mutates its argument cannot corrupt the
        # defaults (or a caller's list). Encoder and decoder still receive the
        # same ``embed_dims`` object, as before.
        embed_dims = list(embed_dims)
        ffn_ratios = list(ffn_ratios)
        depths = list(depths)

        # Sub-modules are created in the original order so that parameter
        # registration order and RNG consumption stay the same.
        self.cls_conv = nn.Sequential(
            nn.Dropout2d(p=0.1),
            nn.Conv2d(dec_outChannels, num_classes, kernel_size=1),
        )
        self.encoder = MSCANet(in_channnels=in_channnels, embed_dims=embed_dims,
                               ffn_ratios=ffn_ratios, depths=depths, num_stages=num_stages,
                               dropout=dropout, drop_path=drop_path)
        self.decoder = HamDecoder(
            outChannels=dec_outChannels, config=config, enc_embed_dims=embed_dims)
        self.init_weights()

    def init_weights(self):
        """Re-initialise the weights of every sub-module in place.

        * ``nn.Linear``: weight ~ N(0, 0.02^2).
        * ``nn.LayerNorm`` / ``nn.BatchNorm2d``: weight = 1, bias = 0.
        * ``nn.Conv2d``: weight ~ N(0, 2 / fan_out) with
          ``fan_out = kH * kW * out_channels // groups``.

        NOTE(review): biases of ``nn.Linear`` and ``nn.Conv2d`` are not touched
        and keep PyTorch's default initialisation -- confirm this is intended.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
                # Norm layers created without affine parameters (or without a
                # bias) expose ``None`` here; skip instead of crashing.
                if m.weight is not None:
                    nn.init.constant_(m.weight, val=1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, val=0.0)
            elif isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                nn.init.normal_(m.weight, std=math.sqrt(2.0 / fan_out), mean=0)

    def forward(self, x):
        """Run the network and return class scores at the input's spatial size.

        Args:
            x: Input tensor; its last two dimensions are used as the target
                size of the final resize.

        Returns:
            The classifier output, bilinearly resized (``align_corners=True``)
            to ``x.size()[-2:]``.
        """
        enc_feats = self.encoder(x)
        dec_out = self.decoder(enc_feats)
        # Per the original author's note the logits are B x C x H/8 x W/8 here
        # (depends on the encoder/decoder, which are defined elsewhere).
        output = self.cls_conv(dec_out)
        # Upsample back to the input resolution.
        output = F.interpolate(output, size=x.size()[-2:], mode='bilinear',
                               align_corners=True)
        return output