-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsegtransformer.py
285 lines (232 loc) · 10.5 KB
/
segtransformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import numpy as np
# DL library imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torchstat import stat
# libraries for loading image, plotting
import cv2
import matplotlib.pyplot as plt
from einops import rearrange
from timm.models.layers import drop_path, trunc_normal_
### Mix Transformer / Encoder
"""
First, we'll gather the ingredients from a single transformer stage:
* Overlap Patch Embedding
* Efficient Self-Attention
* Mix FFNs
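**Efficient Self-Attention and Mix-FFN form a single Transformer Block; the Overlap Patch Embedding runs once at the start of each stage, before its blocks.**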
"""
class overlap_patch_embed(nn.Module):
    """Overlapping patch embedding: strided conv -> flatten to tokens -> LayerNorm.

    The convolution uses a ``patch_size`` kernel with ``patch_size // 2``
    padding, so neighbouring patches overlap whenever ``stride < patch_size``.

    Args:
        patch_size: convolution kernel size (square).
        stride: convolution stride (square).
        in_chans: channels of the incoming image / feature map.
        embed_dim: channels (token width) produced by the convolution.
    """

    def __init__(self, patch_size, stride, in_chans, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        half = patch_size // 2
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=(half, half),
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Embed ``x`` (4-D, ``in_chans`` channels) into normalised tokens.

        Returns:
            ``(tokens, h, w)`` where ``tokens`` is ``(batch, h * w, embed_dim)``
            and ``h``, ``w`` are the spatial size after the convolution.
        """
        feat = self.proj(x)
        _, _, height, width = feat.shape
        tokens = self.norm(rearrange(feat, 'b c h w -> b (h w) c'))
        return tokens, height, width
class efficient_self_attention(nn.Module):
    """Multi-head self-attention with spatially reduced keys/values.

    Queries are computed for every input token. When ``sr_ratio > 1`` the
    tokens used for keys/values are folded back into an ``h x w`` map,
    downsampled by a ``sr_ratio x sr_ratio`` convolution with the same stride,
    and layer-normalised. This roughly divides the number of key/value tokens
    by ``sr_ratio ** 2``.

    Args:
        attn_dim: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout_p: dropout probability applied after the output projection.
        sr_ratio: key/value spatial-reduction factor (``1`` disables it).
    """

    def __init__(self, attn_dim, num_heads, dropout_p, sr_ratio):
        super().__init__()
        assert attn_dim % num_heads == 0, f'expected attn_dim {attn_dim} to be a multiple of num_heads {num_heads}'
        self.attn_dim = attn_dim
        self.num_heads = num_heads
        self.dropout_p = dropout_p
        self.sr_ratio = sr_ratio

        # Spatial reduction for keys/values (only built when it does something).
        if sr_ratio > 1:
            self.sr = nn.Conv2d(attn_dim, attn_dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(attn_dim)

        # Query projection plus a fused key/value projection (2 * attn_dim outputs).
        self.q = nn.Linear(attn_dim, attn_dim, bias=True)
        self.kv = nn.Linear(attn_dim, attn_dim * 2, bias=True)
        head_dim = attn_dim // num_heads
        self.scale = head_dim ** -0.5

        # Merges the concatenated per-head outputs back to ``attn_dim``.
        self.proj = nn.Linear(attn_dim, attn_dim)

    def forward(self, x, h, w):
        """Attend over tokens ``x`` laid out as ``(batch, h * w, attn_dim)``.

        ``h`` and ``w`` are only needed to rebuild the spatial map for the
        key/value reduction. Returns a tensor with the same shape as ``x``.
        """
        heads = self.num_heads
        query = rearrange(self.q(x), 'b hw (m c) -> b m hw c', m=heads)

        source = x
        if self.sr_ratio > 1:
            grid = rearrange(source, 'b (h w) c -> b c h w', h=h, w=w)
            source = rearrange(self.sr(grid), 'b c h w -> b (h w) c')
            source = self.norm(source)

        # Split the fused projection into keys and values (leading axis of size 2).
        key, value = rearrange(self.kv(source), 'b d (a m c) -> a b m d c', a=2, m=heads)

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)
        context = rearrange(torch.matmul(weights, value), 'b m hw c -> b hw (m c)')
        return F.dropout(self.proj(context), p=self.dropout_p, training=self.training)
class mix_feedforward(nn.Module):
    """Mix-FFN: Linear -> depthwise 3x3 conv -> GELU -> dropout -> Linear -> dropout.

    The depthwise convolution runs on the tokens reshaped to their ``h x w``
    grid, mixing each channel with its spatial neighbours.

    Args:
        in_features: token width of the input.
        out_features: token width of the output.
        hidden_features: width of the hidden (expanded) layer.
        dropout_p: dropout probability used after GELU and after ``fc2``.
    """

    def __init__(self, in_features, out_features, hidden_features, dropout_p=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        # Depthwise conv: one 3x3 filter per channel (groups == channels).
        self.conv = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=(3, 3),
            padding=(1, 1),
            bias=True,
            groups=hidden_features,
        )
        self.dropout_p = dropout_p

    def forward(self, x, h, w):
        """Apply the Mix-FFN to tokens ``x`` of shape ``(batch, h * w, in_features)``."""
        grid = rearrange(self.fc1(x), 'b (h w) c -> b c h w', h=h, w=w)
        hidden = rearrange(self.conv(grid), 'b c h w -> b (h w) c')
        hidden = F.dropout(F.gelu(hidden), p=self.dropout_p, training=self.training)
        return F.dropout(self.fc2(hidden), p=self.dropout_p, training=self.training)
class transformer_block(nn.Module):
    """Pre-norm transformer block used inside each MiT stage.

    Layout: ``x + DropPath(Attn(LN(x)))`` followed by
    ``x + DropPath(MixFFN(LN(x)))``. The FFN hidden width is ``4 * dim``.

    Args:
        dim: token width.
        num_heads: attention heads.
        dropout_p: dropout probability passed to attention and FFN.
        drop_path_p: stochastic-depth probability for both residual branches.
        sr_ratio: key/value spatial-reduction ratio for the attention.
    """

    def __init__(self, dim, num_heads, dropout_p, drop_path_p, sr_ratio):
        super().__init__()
        self.attn = efficient_self_attention(
            attn_dim=dim, num_heads=num_heads, dropout_p=dropout_p, sr_ratio=sr_ratio
        )
        self.ffn = mix_feedforward(dim, dim, hidden_features=dim * 4, dropout_p=dropout_p)
        self.drop_path_p = drop_path_p
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x, h, w):
        """Run both residual sub-layers on tokens ``x`` (spatial size ``h x w``)."""
        sublayers = ((self.norm1, self.attn), (self.norm2, self.ffn))
        for pre_norm, branch in sublayers:
            update = branch(pre_norm(x), h, w)
            update = drop_path(update, drop_prob=self.drop_path_p, training=self.training)
            x = update + x
        return x
"""Then, we'll build a set of transformer stages: each stage is Nx Transformer Blocks:"""
class mix_transformer_stage(nn.Module):
    """One encoder stage: patch embedding -> transformer blocks -> LayerNorm.

    Args:
        patch_embed: module returning ``(tokens, h, w)`` for an input map.
        blocks: iterable of blocks called as ``block(tokens, h, w)``.
        norm: normalisation applied to the tokens after the last block.

    ``forward`` returns the tokens reshaped back to a ``(b, c, h, w)`` map.
    """

    def __init__(self, patch_embed, blocks, norm):
        super().__init__()
        self.patch_embed = patch_embed
        self.blocks = nn.ModuleList(blocks)
        self.norm = norm

    def forward(self, x):
        tokens, height, width = self.patch_embed(x)
        for blk in self.blocks:
            tokens = blk(tokens, height, width)
        normed = self.norm(tokens)
        return rearrange(normed, 'b (h w) c -> b c h w', h=height, w=width)
"""Finally, we create the full decoder, which is a set of 4 consecutive transformer stages (what's incorrectly called "block" in the image). **Each Transformer Stage is made of Nx Transformer Blocks.** Each block is made of Efficient Self-Attention, Mix FFNs, and Overlap Patch Merging."""
class mix_transformer(nn.Module):
    """Hierarchical Mix Transformer (MiT) encoder.

    One stage is built per entry of ``depths``. Stage 0 embeds the input with
    a 7x7 / stride-4 overlapping patch embedding. Every later stage embeds the
    previous stage's output with a 3x3 / stride-2 one. Each stage then runs
    ``depths[i]`` transformer blocks followed by a LayerNorm.

    Args:
        in_chans: channels of the input image.
        embed_dims: token width of each stage.
        num_heads: attention heads of each stage.
        depths: number of transformer blocks in each stage.
        sr_ratios: key/value spatial-reduction ratio of each stage.
        dropout_p: dropout probability used inside every block.
        drop_path_p: stochastic-depth rate of the last block. Rates increase
            linearly from 0 over all blocks of the encoder.

    ``embed_dims``, ``num_heads`` and ``sr_ratios`` are indexed per stage, so
    each needs at least ``len(depths)`` entries.
    """

    def __init__(self, in_chans, embed_dims, num_heads, depths,
                 sr_ratios, dropout_p, drop_path_p):
        super().__init__()
        self.stages = nn.ModuleList()
        stage_rates = self._drop_path_rates(depths, drop_path_p)
        for stage_i, rates in enumerate(stage_rates):
            dim = embed_dims[stage_i]
            # Each stage: overlap patch embedding -> transformer blocks -> norm.
            # Blocks are created before the embedding, as in the original
            # construction order.
            blocks = [
                transformer_block(dim=dim, num_heads=num_heads[stage_i],
                                  dropout_p=dropout_p, drop_path_p=rate,
                                  sr_ratio=sr_ratios[stage_i])
                for rate in rates
            ]
            if stage_i == 0:
                patch_size, stride, stage_in_chans = 7, 4, in_chans
            else:
                patch_size, stride, stage_in_chans = 3, 2, embed_dims[stage_i - 1]
            patch_embed = overlap_patch_embed(patch_size, stride=stride,
                                              in_chans=stage_in_chans, embed_dim=dim)
            norm = nn.LayerNorm(dim, eps=1e-6)
            self.stages.append(mix_transformer_stage(patch_embed, blocks, norm))

    @staticmethod
    def _drop_path_rates(depths, drop_path_p):
        """Return per-block stochastic-depth rates, grouped per stage.

        Block ``k`` (counted over the whole encoder, starting at 0) gets
        ``drop_path_p * k / (total - 1)``. The denominator is clamped to 1,
        so a single-block encoder gets rate 0 instead of dividing by zero.
        """
        total = sum(depths)
        denom = max(total - 1, 1)
        rates = []
        first = 0  # global index of the current stage's first block
        for depth in depths:
            rates.append([drop_path_p * (first + i) / denom for i in range(depth)])
            first += depth
        return rates

    def forward(self, x):
        """Run every stage in order and return a list with each stage's output."""
        outputs = []
        for stage in self.stages:
            x = stage(x)
            outputs.append(x)
        return outputs
"""### Decoder Head"""
class segformer_head(nn.Module):
    """All-MLP SegFormer decoder head.

    Each encoder feature map is projected to ``embed_dim`` channels by its own
    1x1 conv. The projections are resized to the spatial size of ``x[0]`` and
    concatenated deepest-first. They are then fused by 1x1 conv -> BatchNorm ->
    ReLU, passed through dropout, and classified by a final 1x1 conv.

    Args:
        in_channels: channels of each encoder feature map, in the same order
            as the list passed to ``forward`` (presumably shallow -> deep).
        num_classes: number of output channels (class logits).
        embed_dim: common projection / fusion width.
        dropout_p: dropout probability before the classifier.
    """

    def __init__(self, in_channels, num_classes, embed_dim, dropout_p=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.dropout_p = dropout_p

        # One 1x1 projection per feature map, ordered deepest-first.
        projections = [nn.Conv2d(c, embed_dim, (1, 1)) for c in reversed(in_channels)]
        self.layers = nn.ModuleList(projections)
        fused_chans = embed_dim * len(self.layers)
        self.linear_fuse = nn.Conv2d(fused_chans, embed_dim, (1, 1), bias=False)
        self.bn = nn.BatchNorm2d(embed_dim, eps=1e-5)
        # Per-pixel classifier producing ``num_classes`` channels.
        self.linear_pred = nn.Conv2d(self.embed_dim, num_classes, kernel_size=(1, 1))
        self.init_weights()

    def init_weights(self):
        """Kaiming-init the fusion conv and reset BatchNorm to identity."""
        nn.init.kaiming_normal_(self.linear_fuse.weight, mode='fan_out', nonlinearity='relu')
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        """Decode a sequence of 4-D feature maps into logits at ``x[0]``'s resolution.

        NOTE(review): ``x[0]`` is assumed to be the highest-resolution map
        (stride 4 for the MiT encoder). The last projected map, which comes from
        ``x[0]``, is therefore not resized.
        """
        target_hw = x[0].shape[2:]
        projected = [proj(feat) for proj, feat in zip(self.layers, reversed(x))]
        resized = [
            F.interpolate(p, size=target_hw, mode='bilinear', align_corners=False)
            for p in projected[:-1]
        ]
        resized.append(projected[-1])

        fused = self.linear_fuse(torch.cat(resized, dim=1))
        fused = F.relu(self.bn(fused), inplace=True)
        fused = F.dropout(fused, p=self.dropout_p, training=self.training)
        return self.linear_pred(fused)
"""### Full Segformer"""
class segformer_mit_b3(nn.Module):
    """SegFormer with the MiT-B3 encoder configuration and an all-MLP head.

    The logits produced by the decoder head are bilinearly upsampled to the
    spatial size of the input image.

    Args:
        in_channels: channels of the input image.
        num_classes: number of output classes (logit channels).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        stage_dims = (64, 128, 320, 512)
        # Encoder: four MiT stages (B3 depths / heads / reduction ratios).
        self.backbone = mix_transformer(
            in_chans=in_channels,
            embed_dims=stage_dims,
            num_heads=(1, 2, 5, 8),
            depths=(3, 4, 18, 3),
            sr_ratios=(8, 4, 2, 1),
            dropout_p=0.0,
            drop_path_p=0.1,
        )
        # Decoder: fuses all four stage outputs into per-pixel logits.
        self.decoder_head = segformer_head(
            in_channels=stage_dims, num_classes=num_classes, embed_dim=256
        )
        # Applied last, so it overrides any per-module initialisation above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; other module types are left untouched.

        Linear: truncated-normal weight (std 0.02), zero bias.
        LayerNorm: weight 1, bias 0.
        Conv2d: Kaiming-normal weight (fan_out, relu); bias unchanged.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        """Return class logits with the same height/width as the 4-D input ``x``."""
        out_hw = x.shape[2:]
        logits = self.decoder_head(self.backbone(x))
        return F.interpolate(logits, size=out_hw, mode='bilinear', align_corners=False)
if __name__ == "__main__":
    # Smoke run: build the model, print it, then print layer-wise summaries.
    NUM_CLASSES = 30
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = segformer_mit_b3(in_channels=3, num_classes=NUM_CLASSES).to(device)
    print(model)
    # torchsummary.summary defaults to device="cuda". Pass the device the model
    # actually lives on so the script also works on CPU-only machines.
    summary(model, (3, 1024, 512), device=device.type)
    # Move the model to the CPU before handing it to torchstat.
    stat(model.to("cpu"), (3, 1024, 512))