24
24
import torch .optim as optim
25
25
from networks .swin_config import get_swin_config
26
26
import requests
27
- from config import parse_args
28
27
import gdown
29
28
import matplotlib .pyplot as plt
30
-
31
- # parser = argparse.ArgumentParser()
32
- #
33
- # parser.add_argument('--max_iterations', type=int,
34
- # default=30000, help='maximum epoch number to train')
35
- # parser.add_argument('--max_epochs', type=int,
36
- # default=200, help='maximum epoch number to train')
37
- # parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
38
- # parser.add_argument('--deterministic', type=int, default=1,
39
- # help='whether use deterministic training')
40
- # parser.add_argument('--base_lr', type=float, default=0.01,
41
- # help='segmentation network learning rate')
42
- # parser.add_argument('--img_size', type=int,
43
- # default=224, help='input patch size of network input')
44
- # parser.add_argument('--seed', type=int,
45
- # default=1234, help='random seed')
46
- # parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
47
- # parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
48
- # help='no: no cache, '
49
- # 'full: cache all data, '
50
- # 'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
51
- # parser.add_argument('--resume', help='resume from checkpoint')
52
- # parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
53
- # parser.add_argument('--use-checkpoint', action='store_true',
54
- # help="whether to use gradient checkpointing to save memory")
55
- # parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
56
- # help='mixed precision opt level, if O0, no amp is used')
57
- # parser.add_argument('--tag', help='tag of experiment')
58
- # parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
59
- # parser.add_argument('--throughput', action='store_true', help='Test throughput only')
29
+ from logHelper import setup_logger
30
+ from config import output_file , parse_args
31
+ from networks .YourNet import Your_Net
32
+ from networks .GT_UNet import GT_U_Net
33
+ from networks .model .BiSeNet import BiSeNet
34
+ from networks .model .DDRNet import DDRNet
35
+ from networks .model .DeeplabV3Plus import Deeplabv3plus_res50
36
+ from networks .model .FCN_ResNet import FCN_ResNet
37
+ from networks .model .HRNet import HighResolutionNet
38
+ from networks .SegNet import SegNet
60
39
61
40
args = parse_args ()
62
41
@@ -100,10 +79,7 @@ def download_model(url, destination):
100
79
101
80
102
81
class DynamicDataset (data .Dataset ):
103
- def __init__ (self , img_path , gt_path , data_end_json , size = None ):
104
-
105
- with open (data_end_json ) as f :
106
- self .file_end = json .load (f )['file_ending' ]
82
+ def __init__ (self , img_path , gt_path , size = None ):
107
83
108
84
self .img_name = os .listdir (img_path )
109
85
self .size = size
@@ -114,11 +90,13 @@ def __getitem__(self, item):
114
90
imagename = self .img_name [item ]
115
91
img_path = os .path .join (self .img_path , imagename )
116
92
117
- if self .file_end in ['.png' , '.bmp' , '.tif' ]:
93
+ file_end = imagename .split ('.' )[- 1 ]
94
+
95
+ if file_end in ['png' ]:
118
96
npimg = cv2 .imread (img_path , cv2 .IMREAD_UNCHANGED )
119
97
npimg = np .array (npimg )
120
98
121
- elif self . file_end in ['. gz' , '. nrrd' , '. mha' , '. nii.gz' , '. nii' ]:
99
+ elif file_end in ['gz' , 'nrrd' , 'mha' , 'nii.gz' , 'nii' ]:
122
100
npimg = sitk .ReadImage (img_path )
123
101
npimg = sitk .GetArrayFromImage (npimg )
124
102
@@ -138,14 +116,14 @@ def __getitem__(self, item):
138
116
antialias = None )
139
117
npimg = adapt_size (npimg )
140
118
141
- return npimg , imagename . replace ( '_0000' , '' ) , ori_shape
119
+ return npimg , imagename , ori_shape
142
120
143
121
def __len__ (self ):
144
122
size = int (len (self .img_name ))
145
123
return size
146
124
147
125
148
- # if __name__ == "__main__":
126
+
149
127
150
128
def test_model ():
151
129
cudnn .benchmark = False
@@ -155,37 +133,31 @@ def test_model():
155
133
torch .manual_seed (args .seed )
156
134
torch .cuda .manual_seed (args .seed )
157
135
136
+ data_json_file = os .path .join (os .environ ['medseg_raw' ], os .environ ['current_dataset' ], 'dataset.json' )
137
+
138
+ with open (data_json_file ) as f :
139
+ json_data = json .load (f )
140
+ num_classes = json_data ['label_class_num' ]
141
+ in_channels = json_data ['img_channel' ]
142
+ args .img_size = json_data ['imgae_size' ]
143
+
144
+
158
145
device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
159
- if device .type == 'cuda' :
160
- total_memory = torch .cuda .get_device_properties (device ).total_memory / (1024 ** 3 ) # bytes to GB
161
- args .batch_size = int (total_memory / 10 ) * 2
162
- else :
163
- args .batch_size = 2
164
146
165
- fold = os .environ ['current_fold' ]
166
147
167
- data_json_file = os .path .join (os .environ ['nnUNet_raw' ], os .environ ['current_dataset' ], 'dataset.json' )
168
- # split_json_path = os.path.join(os.environ['nnUNet_preprocessed'], os.environ['current_dataset'], 'splits_final.json')
169
- # base_json_path = os.path.join(os.environ['nnUNet_preprocessed'], os.environ['current_dataset'])
170
- output_folder_test = os .path .join (os .environ ['nnUNet_results' ], os .environ ['MODEL_NAME' ],
171
- os .environ ['current_dataset' ], 'nnUNetTrainer__nnUNetPlans__2d' , 'test_pred' )
172
- output_folder_5fold = os .path .join (os .environ ['nnUNet_results' ], os .environ ['MODEL_NAME' ],
173
- os .environ ['current_dataset' ], 'nnUNetTrainer__nnUNetPlans__2d' , f'fold_{ fold } ' )
148
+ args .batch_size = 1
149
+
150
+
151
+ output_folder_test = os .path .join (os .environ ['medseg_results' ], os .environ ['current_dataset' ], os .environ ['MODEL_NAME' ], 'test_pred' )
174
152
175
153
os .makedirs (output_folder_test , exist_ok = True )
176
- os . makedirs ( output_folder_5fold , exist_ok = True )
154
+
177
155
178
- imageTr_path = os .path .join (os .environ ['nnUNet_raw' ], os .environ ['current_dataset' ], 'imagesTr' )
179
- labelTr_path = os .path .join (os .environ ['nnUNet_raw' ], os .environ ['current_dataset' ], 'labelsTr' )
180
- imageTs_path = os .path .join (os .environ ['nnUNet_raw' ], os .environ ['current_dataset' ], 'imagesTs' )
181
- labelTs_path = os .path .join (os .environ ['nnUNet_raw' ], os .environ ['current_dataset' ], 'labelsTs' )
182
156
183
- with open (data_json_file ) as f :
184
- json_data = json .load (f )
185
- num_classes = len (json_data ['labels' ])
186
- in_channels = len (json_data ['channel_names' ])
157
+ imageTs_path = os .path .join (os .environ ['medseg_raw' ], os .environ ['current_dataset' ], 'imagesTs' )
158
+ labelTs_path = os .path .join (os .environ ['medseg_raw' ], os .environ ['current_dataset' ], 'labelsTs' )
159
+ weights_path = os .path .join (os .environ ['medseg_results' ], os .environ ['current_dataset' ], os .environ ['MODEL_NAME' ], 'checkpoint_final.pth' )
187
160
188
- weights_path = os .path .join (output_folder_5fold , 'checkpoint_final.pth' )
189
161
190
162
model_name = os .environ ['MODEL_NAME' ]
191
163
if model_name == 'unet' :
@@ -211,6 +183,7 @@ def test_model():
211
183
elif model_name == 'swinunet' :
212
184
args .cfg = './networks/swin_tiny_patch4_window7_224_lite.yaml'
213
185
args .opts = None
186
+ args .img_size = 224
214
187
swin_config = get_swin_config (args )
215
188
model = SwinUnet (swin_config , img_size = 224 , num_classes = num_classes ).cuda ()
216
189
url = "https://drive.google.com/uc?id=1TyMf0_uvaxyacMmVzRfqvLLAWSOE2bJR"
@@ -226,37 +199,52 @@ def test_model():
226
199
elif model_name == 'r2unet' :
227
200
model = R2U_Net (in_ch = in_channels , out_ch = num_classes ).cuda ()
228
201
202
+ elif model_name == 'gtunet' :
203
+ model = GT_U_Net (in_ch = in_channels , out_ch = num_classes ).to (device )
204
+ args .img_size = 256
205
+
206
+ elif model_name == 'bisenet' :
207
+ model = BiSeNet (in_ch = in_channels , out_ch = num_classes ).to (device )
208
+
209
+ elif model_name == 'ddrnet' :
210
+ model = DDRNet (in_ch = in_channels , out_ch = num_classes ).to (device )
211
+
212
+ elif model_name == 'deeplabv3plus' :
213
+ model = Deeplabv3plus_res50 (in_ch = in_channels , out_ch = num_classes ).to (device )
214
+
215
+ elif model_name == 'hrnet' :
216
+ model = HighResolutionNet (in_ch = in_channels , out_ch = num_classes ).to (device )
217
+
218
+ elif model_name == 'segnet' :
219
+ model = SegNet (in_ch = in_channels , out_ch = num_classes ).to (device )
220
+
221
+ elif model_name == 'fcnresnet' :
222
+ model = FCN_ResNet (in_ch = in_channels , out_ch = num_classes ).to (device )
223
+
224
+ elif model_name == 'yournet' :
225
+ model = Your_Net (in_ch = in_channels , out_ch = num_classes ).to (device )
226
+
227
+
229
228
else :
230
229
raise NotImplementedError (f"model_name { model_name } not supported" )
231
230
232
231
model .load_state_dict (torch .load (weights_path ))
233
232
234
- logging .basicConfig (filename = "logging.txt" , level = logging .INFO ,
235
- format = '[%(asctime)s.%(msecs)03d] %(message)s' , datefmt = '%H:%M:%S' )
236
- logging .getLogger ().addHandler (logging .StreamHandler (sys .stdout ))
237
- logging .info (str (args ))
233
+ logger = setup_logger ("training_logger" , output_file = output_file )
234
+ logger .info ("Process started" )
235
+ logger .info (str (args ))
238
236
base_lr = args .base_lr
239
- batch_size = args .batch_size * args .n_gpu
240
- if model_name == 'swinunet' or model_name == 'transunet' :
241
- db_test = DynamicDataset (img_path = imageTs_path , gt_path = labelTs_path , data_end_json = data_json_file ,
242
- size = args .img_size )
243
- else :
244
- db_test = DynamicDataset (img_path = imageTs_path , gt_path = labelTs_path , data_end_json = data_json_file )
237
+
238
+ db_test = DynamicDataset (img_path = imageTs_path , gt_path = labelTs_path , size = args .img_size )
245
239
246
- with open (data_json_file ) as f :
247
- file_end = json .load (f )['file_ending' ]
240
+
248
241
249
- testloader = DataLoader (db_test , batch_size = 1 , shuffle = True , num_workers = 2 , pin_memory = True )
242
+ testloader = DataLoader (db_test , batch_size = args . batch_size , shuffle = True , num_workers = 2 , pin_memory = True )
250
243
if args .n_gpu > 1 :
251
244
model = nn .DataParallel (model )
252
- model .train ()
253
- ce_loss = CrossEntropyLoss ()
254
- dice_loss = DiceLoss (num_classes )
255
- optimizer = optim .SGD (model .parameters (), lr = base_lr , momentum = 0.9 , weight_decay = 0.0001 )
256
-
257
- best_performance = 0.0
258
- val_dice_scores = []
259
- epoch_numbers = []
245
+ model .eval ()
246
+
247
+
260
248
for i_batch , (img , img_name , ori_shape ) in enumerate (testloader ):
261
249
image_batch = img
262
250
image_batch = image_batch .cuda ()
@@ -268,11 +256,16 @@ def test_model():
268
256
align_corners = True )
269
257
pred = outputs .data .max (1 )[1 ].squeeze_ (1 ).squeeze_ (0 ).cpu ().numpy ()
270
258
pred = pred .astype (np .uint8 )
259
+ print (f"Processing { img_name [0 ]} " )
271
260
272
- if file_end in ['.png' , '.bmp' , '.tif' ]:
261
+ file_end = img_name [0 ].split ('.' )[- 1 ]
262
+ if file_end in ['png' , 'bmp' , 'tif' ]:
273
263
pred_img = Image .fromarray (pred )
274
264
pred_img .save (os .path .join (output_folder_test , img_name [0 ]))
275
265
276
- elif file_end in ['. gz' , '. nrrd' , '. mha' , '. nii.gz' , '. nii' ]:
266
+ elif file_end in ['gz' , 'nrrd' , 'mha' , 'nii.gz' , 'nii' ]:
277
267
pred_img = sitk .GetImageFromArray (pred )
278
- sitk .WriteImage (pred_img , os .path .join (output_folder_test , img_name [0 ]))
268
+ sitk .WriteImage (pred_img , os .path .join (output_folder_test , img_name [0 ]))
269
+
270
+ if __name__ == "__main__" :
271
+ test_model ()
0 commit comments