forked from bnsreenu/python_for_microscopists
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path224_225_226_models.py
359 lines (283 loc) · 15.2 KB
/
224_225_226_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
# https://youtu.be/L5iV5BHkMzM
"""
Attention U-net:
https://arxiv.org/pdf/1804.03999.pdf
Recurrent residual Unet (R2U-Net) paper
https://arxiv.org/ftp/arxiv/papers/1802/1802.06955.pdf
(Check fig 4.)
Note: Batch normalization should be performed over channels after a convolution.
In the following code axis is set to 3 because our inputs are of shape
[None, height, width, channel], so the channel axis is axis=3.
Original code from the link below, heavily modified.
https://github.com/MoleImg/Attention_UNet/blob/master/AttResUNet.py
"""
import tensorflow as tf
from tensorflow.keras import models, layers, regularizers
from tensorflow.keras import backend as K
'''
A few useful metrics and losses
'''
def dice_coef(y_true, y_pred, smooth=1.0):
    """Return the smoothed Dice coefficient between y_true and y_pred.

    Both tensors are flattened, so the score is pooled over every element of
    the batch rather than averaged per sample.

    Args:
        y_true: ground-truth tensor (presumably a binary mask — confirm with
            the data pipeline). Any numeric/bool dtype is accepted; it is cast
            to the prediction dtype.
        y_pred: predicted tensor with the same number of elements as y_true.
        smooth: additive smoothing applied to numerator and denominator. It
            keeps the ratio defined when both inputs are all zeros. The
            default 1.0 is the value that used to be hard-coded.

    Returns:
        Scalar tensor: (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth).
    """
    y_pred_f = K.flatten(y_pred)
    # Integer/bool label masks would otherwise make the element-wise product
    # below fail with a dtype mismatch against the float predictions.
    y_true_f = K.flatten(K.cast(y_true, y_pred_f.dtype))
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
def jacard_coef(y_true, y_pred, smooth=1.0):
    """Return the smoothed Jaccard index (intersection over union).

    The name keeps its historical spelling ("jacard") because callers import
    it by that name. Both tensors are flattened, so the score is pooled over
    every element of the batch.

    Args:
        y_true: ground-truth tensor (presumably a binary mask — confirm). Any
            numeric/bool dtype is accepted; it is cast to the prediction dtype.
        y_pred: predicted tensor with the same number of elements as y_true.
        smooth: additive smoothing applied to numerator and denominator. It
            keeps the ratio defined for empty masks. The default 1.0 is the
            value that used to be hard-coded.

    Returns:
        Scalar tensor: (intersection + smooth) / (union + smooth), where
        union = sum(t) + sum(p) - intersection.
    """
    y_pred_f = K.flatten(y_pred)
    # Cast labels so an integer/bool mask can be multiplied with float predictions.
    y_true_f = K.flatten(K.cast(y_true, y_pred_f.dtype))
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)
def jacard_coef_loss(y_true, y_pred):
    """Loss form of jacard_coef: its negation, so minimising it maximises IoU."""
    iou = jacard_coef(y_true, y_pred)
    return -iou
def dice_coef_loss(y_true, y_pred):
    """Loss form of dice_coef: its negation, so minimising it maximises overlap."""
    overlap = dice_coef(y_true, y_pred)
    return -overlap
##############################################################
'''
Useful blocks to build Unet
conv - BN - Activation - conv - BN - Activation - Dropout (if enabled)
'''
def conv_block(x, filter_size, size, dropout, batch_norm=False):
    """Apply two (Conv2D -> optional BatchNorm -> ReLU) stages, then optional dropout.

    Args:
        x: input feature map, channels-last (BatchNormalization uses axis=3).
        filter_size: side length of the square convolution kernel.
        size: number of filters in both convolutions.
        dropout: dropout rate applied once at the end; 0 (or less) disables it.
        batch_norm: if truthy, insert BatchNormalization after each convolution.

    Returns:
        Tensor with `size` channels and the same spatial size as x
        ('same' padding, stride 1).
    """
    conv = x
    # Two identical stages. The loop creates layers in the same order as the
    # original unrolled code.
    for _ in range(2):
        conv = layers.Conv2D(size, (filter_size, filter_size), padding="same")(conv)
        # Truthiness rather than `is True`, so flags such as 1 or numpy bools
        # also enable BN (consistent with gating_signal).
        if batch_norm:
            conv = layers.BatchNormalization(axis=3)(conv)
        conv = layers.Activation("relu")(conv)
    if dropout > 0:
        conv = layers.Dropout(dropout)(conv)
    return conv
def repeat_elem(tensor, rep):
    """Wrap K.repeat_elements in a Lambda layer that repeats along axis 3.

    Example: an input of shape (None, 256, 256, 3) with rep=2 yields a tensor
    of shape (None, 256, 256, 6).
    """
    def _repeat_along_channels(t, repnum):
        return K.repeat_elements(t, repnum, axis=3)

    repeat_layer = layers.Lambda(_repeat_along_channels, arguments={'repnum': rep})
    return repeat_layer(tensor)
def res_conv_block(x, filter_size, size, dropout, batch_norm=False):
    '''
    Residual convolutional block.

    Two common variants differ in where the last activation goes relative to
    the addition with the shortcut:
    1. conv - BN - Activation - conv - BN - Activation
                                   - shortcut - BN - shortcut+BN
    2. conv - BN - Activation - conv - BN
                                   - shortcut - BN - shortcut+BN - Activation
    This function implements variant 2: activation after the addition, as in
    the original ResNet.
    Check fig 4 in https://arxiv.org/ftp/arxiv/papers/1802/1802.06955.pdf

    Args:
        x: input feature map, channels-last (BatchNormalization uses axis=3).
        filter_size: kernel size of the two main-path convolutions.
        size: number of filters in every convolution of the block.
        dropout: dropout rate applied to the main path before the addition;
            0 (or less) disables it.
        batch_norm: if truthy, add BatchNormalization after each convolution,
            on both the main path and the shortcut.

    Returns:
        Tensor with `size` channels and the same spatial size as x.
    '''
    # Main path: conv -> [BN] -> ReLU -> conv -> [BN] -> [Dropout]
    conv = layers.Conv2D(size, (filter_size, filter_size), padding='same')(x)
    if batch_norm:
        conv = layers.BatchNormalization(axis=3)(conv)
    conv = layers.Activation('relu')(conv)
    conv = layers.Conv2D(size, (filter_size, filter_size), padding='same')(conv)
    if batch_norm:
        conv = layers.BatchNormalization(axis=3)(conv)
    # No activation here: variant 2 applies ReLU only after the addition.
    if dropout > 0:
        conv = layers.Dropout(dropout)(conv)

    # Shortcut: a 1x1 conv makes the channel count equal `size` so the
    # element-wise addition is valid.
    shortcut = layers.Conv2D(size, kernel_size=(1, 1), padding='same')(x)
    if batch_norm:
        shortcut = layers.BatchNormalization(axis=3)(shortcut)

    res_path = layers.add([shortcut, conv])
    res_path = layers.Activation('relu')(res_path)
    return res_path
def gating_signal(input, out_size, batch_norm=False):
    """Project a feature map to `out_size` channels for use as an attention gate.

    A 1x1 convolution (stride 1, 'same' padding) changes only the channel
    count; the spatial size of `input` is unchanged. The result is optionally
    batch-normalised, then passed through ReLU.

    :return: the gating feature map with `out_size` channels
    """
    gate = layers.Conv2D(out_size, (1, 1), padding='same')(input)
    gate = layers.BatchNormalization()(gate) if batch_norm else gate
    return layers.Activation('relu')(gate)
def attention_block(x, gating, inter_shape):
    """Additive attention gate (Attention U-Net, arXiv:1804.03999).

    Computes a per-pixel coefficient in (0, 1) from x and the gating signal,
    multiplies it into x, then applies a 1x1 conv and BatchNormalization.

    Args:
        x: skip-connection feature map; the channel axis is 3.
        gating: gating feature map. Its spatial size must divide that of x
            after the stride-2 conv below, or the integer strides degenerate.
            In this file it is the next-deeper level — confirm for new callers.
        inter_shape: number of filters in the intermediate attention space.

    Returns:
        Tensor with x's shape.
    """
    x_shape = K.int_shape(x)
    g_shape = K.int_shape(gating)

    # Bring x down to the gating resolution with a strided 2x2 conv.
    theta_x = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)
    theta_shape = K.int_shape(theta_x)

    # Project gating to inter_shape channels, then resample it with a
    # transposed conv so its spatial size matches theta_x.
    phi_g = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
    g_strides = (theta_shape[1] // g_shape[1], theta_shape[2] // g_shape[2])
    phi_g_up = layers.Conv2DTranspose(inter_shape, (3, 3),
                                      strides=g_strides,
                                      padding='same')(phi_g)

    # Additive attention: relu(phi_g + theta_x) -> 1 channel -> sigmoid.
    joint = layers.add([phi_g_up, theta_x])
    joint = layers.Activation('relu')(joint)
    alpha = layers.Conv2D(1, (1, 1), padding='same')(joint)
    alpha = layers.Activation('sigmoid')(alpha)

    # Upsample the coefficients back to x's spatial size and repeat them
    # across x's channels so they can be multiplied element-wise with x.
    alpha_shape = K.int_shape(alpha)
    up_factor = (x_shape[1] // alpha_shape[1], x_shape[2] // alpha_shape[2])
    alpha = layers.UpSampling2D(size=up_factor)(alpha)
    alpha = repeat_elem(alpha, x_shape[3])

    gated = layers.multiply([alpha, x])
    out = layers.Conv2D(x_shape[3], (1, 1), padding='same')(gated)
    out = layers.BatchNormalization()(out)
    return out
def UNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):
    '''
    Build a plain U-Net: four conv_block + max-pool encoder levels, a
    bottleneck, and four upsample + skip-concatenate + conv_block decoder levels.

    Args:
        input_shape: shape of one input image, channels-last. Height and width
            must be divisible by 16: four 2x2 poolings floor odd sizes, which
            would break the skip concatenations.
        NUM_CLASSES: number of output channels.
        dropout_rate: dropout passed to every conv_block (0 disables it).
        batch_norm: whether conv_block inserts BatchNormalization.

    Returns:
        Uncompiled keras Model named "UNet" with a sigmoid output. A summary
        of the model is printed to stdout as a side effect.
    '''
    FILTER_NUM = 64  # number of filters for the first layer
    FILTER_SIZE = 3  # size of the convolutional filter
    UP_SAMP_SIZE = 2  # size of upsampling filters
    # Filter multipliers of the four encoder/decoder levels; the bottleneck uses 16.
    LEVEL_MULTS = (1, 2, 4, 8)

    inputs = layers.Input(input_shape, dtype=tf.float32)

    # Downsampling path: keep each level's output as a skip connection.
    skips = []
    x = inputs
    for mult in LEVEL_MULTS:
        x = conv_block(x, FILTER_SIZE, mult * FILTER_NUM, dropout_rate, batch_norm)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck: convolution only, no pooling.
    x = conv_block(x, FILTER_SIZE, 16 * FILTER_NUM, dropout_rate, batch_norm)

    # Upsampling path: deepest level first, concatenating the matching skip.
    for mult, skip in zip(reversed(LEVEL_MULTS), reversed(skips)):
        x = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(x)
        x = layers.concatenate([x, skip], axis=3)
        x = conv_block(x, FILTER_SIZE, mult * FILTER_NUM, dropout_rate, batch_norm)

    # 1*1 convolutional layers
    conv_final = layers.Conv2D(NUM_CLASSES, kernel_size=(1, 1))(x)
    conv_final = layers.BatchNormalization(axis=3)(conv_final)
    conv_final = layers.Activation('sigmoid')(conv_final)  # Change to softmax for multichannel

    model = models.Model(inputs, conv_final, name="UNet")
    # summary() prints the table itself and returns None. Wrapping it in
    # print() (as before) also printed a stray "None".
    model.summary()
    return model
def Attention_UNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):
    '''
    U-Net whose skip connections are filtered by attention gates.

    Encoder: four conv_block + 2x2 max-pool levels (64, 128, 256, 512 filters),
    then a 1024-filter bottleneck. Each decoder level:
      1. derives a gating signal from the current deepest features;
      2. passes the matching encoder output through attention_block;
      3. upsamples, concatenates with the gated skip, and applies conv_block.
    The head is a 1x1 conv, BatchNormalization and sigmoid.

    Returns an uncompiled keras Model named "Attention_UNet".
    '''
    base_filters = 64  # filters at the full-resolution level
    kernel = 3  # convolution kernel size
    up_factor = 2  # upsampling factor per decoder level

    inputs = layers.Input(input_shape, dtype=tf.float32)

    # Encoder: store each level's output for the decoder, then pool.
    encoder_outputs = []
    feat = inputs
    for depth in range(4):
        feat = conv_block(feat, kernel, (2 ** depth) * base_filters, dropout_rate, batch_norm)
        encoder_outputs.append(feat)
        feat = layers.MaxPooling2D(pool_size=(2, 2))(feat)

    # Bottleneck: convolution only.
    feat = conv_block(feat, kernel, 16 * base_filters, dropout_rate, batch_norm)

    # Decoder, deepest level first.
    for depth in (3, 2, 1, 0):
        n_filters = (2 ** depth) * base_filters
        gate = gating_signal(feat, n_filters, batch_norm)
        attended = attention_block(encoder_outputs[depth], gate, n_filters)
        feat = layers.UpSampling2D(size=(up_factor, up_factor), data_format="channels_last")(feat)
        feat = layers.concatenate([feat, attended], axis=3)
        feat = conv_block(feat, kernel, n_filters, dropout_rate, batch_norm)

    # Head (swap sigmoid for softmax for multichannel output).
    out = layers.Conv2D(NUM_CLASSES, kernel_size=(1, 1))(feat)
    out = layers.BatchNormalization(axis=3)(out)
    out = layers.Activation('sigmoid')(out)

    return models.Model(inputs, out, name="Attention_UNet")
def Attention_ResUNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):
    '''
    Residual U-Net with attention-gated skip connections.

    Same topology as the attention U-Net, but every double convolution is a
    res_conv_block. The encoder has four residual blocks followed by 2x2
    max-pooling, plus a residual bottleneck. Each decoder level builds a
    gating signal, applies attention_block to the matching encoder output,
    upsamples, concatenates, and applies res_conv_block. The head is a 1x1
    conv, BatchNormalization and sigmoid.

    Returns an uncompiled keras Model named "AttentionResUNet".
    '''
    FILTER_NUM = 64  # number of basic filters for the first layer
    FILTER_SIZE = 3  # size of the convolutional filter
    UP_SAMP_SIZE = 2  # size of upsampling filters
    channel_axis = 3  # channels-last layout
    multipliers = [1, 2, 4, 8]

    inputs = layers.Input(input_shape, dtype=tf.float32)

    # Encoder: residual block, remember it for the skip, then pool.
    skip_tensors = []
    net = inputs
    for m in multipliers:
        net = res_conv_block(net, FILTER_SIZE, m * FILTER_NUM, dropout_rate, batch_norm)
        skip_tensors.append(net)
        net = layers.MaxPooling2D(pool_size=(2, 2))(net)

    # Bottleneck: residual convolution only.
    net = res_conv_block(net, FILTER_SIZE, 16 * FILTER_NUM, dropout_rate, batch_norm)

    # Decoder: walk the levels back up, gating each skip before concatenation.
    while multipliers:
        m = multipliers.pop()
        skip = skip_tensors.pop()
        width = m * FILTER_NUM
        gate = gating_signal(net, width, batch_norm)
        gated_skip = attention_block(skip, gate, width)
        net = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(net)
        net = layers.concatenate([net, gated_skip], axis=channel_axis)
        net = res_conv_block(net, FILTER_SIZE, width, dropout_rate, batch_norm)

    # Head (change sigmoid to softmax for multichannel output).
    head = layers.Conv2D(NUM_CLASSES, kernel_size=(1, 1))(net)
    head = layers.BatchNormalization(axis=channel_axis)(head)
    head = layers.Activation('sigmoid')(head)

    return models.Model(inputs, head, name="AttentionResUNet")
if __name__ == "__main__":
    # Smoke test: build (and, via UNet, print a summary of) a U-Net for
    # 256x256 single-channel images. Guarded so that importing this module
    # for its model builders and metrics no longer builds a full model.
    input_shape = (256, 256, 1)
    UNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True)