
Commit 19e3c60

Add files via upload
1 parent 05a93ad commit 19e3c60

File tree

8 files changed: +722 / -596 lines changed


README.md

Lines changed: 3 additions & 5 deletions
@@ -11,10 +11,6 @@
 The current code is designed for Windows, not for Linux.
 
 
-If you only want to use nnSAM, please install [this](https://github.com/Kent0n-Li/nnSAM). <br>
-如果你只想运行nnSAM,请访问该代码仓库[this](https://github.com/Kent0n-Li/nnSAM)
-
-
 Install (安装步骤):
 
 ```bash
@@ -35,7 +31,7 @@ pip install timm
 pip install git+https://github.com/Kent0n-Li/nnSAM.git
 
 git clone https://github.com/Kent0n-Li/Medical-Image-Segmentation.git
-cd Medical-Image-Segmentation
+cd Medical-Image-Segmentation-Benchmark
 pip install -r requirements.txt
 ```
 
@@ -44,6 +40,8 @@ pip install -r requirements.txt
 python web.py
 ```
 
+If you only want to use nnSAM, please install [this](https://github.com/Kent0n-Li/nnSAM). <br>
+如果你只想运行nnSAM,请访问该代码仓库[this](https://github.com/Kent0n-Li/nnSAM)
 
 样例数据集:[Demo Dataset](https://github.com/Kent0n-Li/Medical-Image-Segmentation/tree/main/Demo_dataset)
 

config.py

Lines changed: 5 additions & 0 deletions
@@ -79,3 +79,8 @@ def parse_args(type=None):
                         default=4, help='Batch Size')
 
     return parser.parse_args()
+
+
+def print_web(text):
+    with open(output_file, 'a') as f:
+        f.write(text + '\n')
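The new `print_web` helper appends a line to `output_file` (defined in config.py and imported elsewhere in this commit, e.g. by test.py), so long-running scripts can push progress messages into a log the web UI can read. A minimal usage sketch, assuming `output_file` resolves to a writable text file:

```python
# Hypothetical usage of the new helper; `output_file` comes from config.py
# and its actual value (a log file path) is not shown in this diff.
from config import print_web, output_file

print_web("Preprocessing started")         # appends one line to output_file
print_web("Generating dataset.json ...")

with open(output_file) as f:               # the web front end can read this back
    print(f.read())
```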

data_process.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import os
+from config import *
+from web import *
+
+
+if __name__ == '__main__':
+    # Define target paths
+    target_paths = {
+        'training_image': os.path.join(os.environ['medseg_raw'], os.environ['current_dataset'], 'imagesTr'),
+        'training_label': os.path.join(os.environ['medseg_raw'], os.environ['current_dataset'], 'labelsTr'),
+        'testing_image': os.path.join(os.environ['medseg_raw'], os.environ['current_dataset'], 'imagesTs'),
+        'testing_label': os.path.join(os.environ['medseg_raw'], os.environ['current_dataset'], 'labelsTs'),
+        'validation_image': os.path.join(os.environ['medseg_raw'], os.environ['current_dataset'], 'imagesVal'),
+        'validation_label': os.path.join(os.environ['medseg_raw'], os.environ['current_dataset'], 'labelsVal'),
+    }
+    try:
+
+        # jpg to png
+        for path_each in target_paths.values():
+            convert_jpg_to_png_all_from_path(path_each)
+
+
+        unique_values = find_unique_labels(target_paths['training_label'])
+
+        convert_label_by_searchsorted(target_paths['training_label'], unique_values)
+        convert_label_by_searchsorted(target_paths['validation_label'], unique_values)
+        convert_label_by_searchsorted(target_paths['testing_label'], unique_values)
+
+        # Generate dataset.json
+        print_web(f"Generating dataset.json for {os.environ['current_dataset']}")
+        npimg_path = os.path.join(target_paths['training_image'], os.listdir(target_paths['training_image'])[0])
+        npimg = cv2.imread(npimg_path, cv2.IMREAD_UNCHANGED)
+
+        img_channel = 3 if len(npimg.shape) == 3 else 1
+        label_class_num = len(unique_values)
+
+        dataset_id = len(os.listdir(os.environ['medseg_raw'])) + 1
+        dataset_id = "{:03}".format(dataset_id)
+
+        image_size = npimg.shape[0]
+
+        save_dataset_json(dataset_id, os.environ['current_dataset'], image_size, img_channel, label_class_num)
+
+    except Exception as e:
+        print_web(f"Error: {e}")
+        raise e
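data_process.py leans on helpers imported from `web` (`convert_jpg_to_png_all_from_path`, `find_unique_labels`, `convert_label_by_searchsorted`, `save_dataset_json`) whose bodies are not part of this commit. The apparent intent of the remapping step is to collapse the raw pixel values found in the label images (for example 0/127/255) into contiguous class indices 0..K-1 before training. A rough sketch of what `convert_label_by_searchsorted` might look like, under that assumption:

```python
# Sketch only: not the actual implementation from web.py.
import os
import cv2
import numpy as np

def convert_label_by_searchsorted(label_dir, unique_values):
    """Remap raw label values to 0..K-1 class indices.

    Assumes `unique_values` is the set of distinct pixel values found in the
    training labels (as returned by find_unique_labels).
    """
    lut = np.asarray(sorted(unique_values))
    for name in os.listdir(label_dir):
        path = os.path.join(label_dir, name)
        label = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        # For values present in lut, searchsorted returns their index in lut,
        # i.e. the contiguous class id.
        remapped = np.searchsorted(lut, label).astype(np.uint8)
        cv2.imwrite(path, remapped)
```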

jpg_to_png.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import os
+from PIL import Image
+import numpy as np
+
+folder_path_list = ['E:\\Demo_dataset\\Kvasir-SEG\\train\\image','E:\\Demo_dataset\\Kvasir-SEG\\test\\image']
+for folder_path in folder_path_list:
+    for file in os.listdir(folder_path):
+        if file.endswith('.jpg'):
+            img = Image.open(os.path.join(folder_path, file))
+            folder_path_new = folder_path+'_png'
+            os.makedirs(folder_path_new, exist_ok=True)
+            img.save(os.path.join(folder_path_new, file.replace('.jpg', '.png')))
+            print(f'{file} converted to png')
+
+
+
+folder_path_list = ['E:\\Demo_dataset\\Kvasir-SEG\\train\\label','E:\\Demo_dataset\\Kvasir-SEG\\test\\label']
+
+for folder_path in folder_path_list:
+    for file in os.listdir(folder_path):
+        if file.endswith('.jpg'):
+            img = Image.open(os.path.join(folder_path, file))
+            # threshold the image
+            img = np.array(img)
+            img[img > 100] = 255
+            img[img <= 100] = 0
+            img = Image.fromarray(img)
+
+            img = img.convert('L')
+            folder_path_new = folder_path+'_png'
+            os.makedirs(folder_path_new, exist_ok=True)
+
+            img.save(os.path.join(folder_path_new, file.replace('.jpg', '.png')))
+            print(f'{file} converted to png')
+
+
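The second loop thresholds the label JPEGs at 100 because JPEG compression smears a binary mask's boundary into intermediate gray values; clamping everything above the threshold to 255 and the rest to 0 restores a clean two-class mask. A tiny numpy illustration of the effect, using made-up pixel values:

```python
import numpy as np

# Hypothetical row of label pixels after JPEG compression: the mask edge
# has been smeared into intermediate gray values.
row = np.array([0, 3, 97, 180, 251, 255], dtype=np.uint8)

binary = row.copy()
binary[binary > 100] = 255   # same threshold as jpg_to_png.py
binary[binary <= 100] = 0

print(binary)                # [  0   0   0 255 255 255] -> clean binary mask
```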

test.py

Lines changed: 79 additions & 86 deletions
@@ -24,39 +24,18 @@
 import torch.optim as optim
 from networks.swin_config import get_swin_config
 import requests
-from config import parse_args
 import gdown
 import matplotlib.pyplot as plt
-
-# parser = argparse.ArgumentParser()
-#
-# parser.add_argument('--max_iterations', type=int,
-#                     default=30000, help='maximum epoch number to train')
-# parser.add_argument('--max_epochs', type=int,
-#                     default=200, help='maximum epoch number to train')
-# parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
-# parser.add_argument('--deterministic', type=int, default=1,
-#                     help='whether use deterministic training')
-# parser.add_argument('--base_lr', type=float, default=0.01,
-#                     help='segmentation network learning rate')
-# parser.add_argument('--img_size', type=int,
-#                     default=224, help='input patch size of network input')
-# parser.add_argument('--seed', type=int,
-#                     default=1234, help='random seed')
-# parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
-# parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
-#                     help='no: no cache, '
-#                          'full: cache all data, '
-#                          'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
-# parser.add_argument('--resume', help='resume from checkpoint')
-# parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
-# parser.add_argument('--use-checkpoint', action='store_true',
-#                     help="whether to use gradient checkpointing to save memory")
-# parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
-#                     help='mixed precision opt level, if O0, no amp is used')
-# parser.add_argument('--tag', help='tag of experiment')
-# parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
-# parser.add_argument('--throughput', action='store_true', help='Test throughput only')
+from logHelper import setup_logger
+from config import output_file, parse_args
+from networks.YourNet import Your_Net
+from networks.GT_UNet import GT_U_Net
+from networks.model.BiSeNet import BiSeNet
+from networks.model.DDRNet import DDRNet
+from networks.model.DeeplabV3Plus import Deeplabv3plus_res50
+from networks.model.FCN_ResNet import FCN_ResNet
+from networks.model.HRNet import HighResolutionNet
+from networks.SegNet import SegNet
 
 args = parse_args()
 
@@ -100,10 +79,7 @@ def download_model(url, destination):
 
 
 class DynamicDataset(data.Dataset):
-    def __init__(self, img_path, gt_path, data_end_json, size=None):
-
-        with open(data_end_json) as f:
-            self.file_end = json.load(f)['file_ending']
+    def __init__(self, img_path, gt_path, size=None):
 
         self.img_name = os.listdir(img_path)
         self.size = size
@@ -114,11 +90,13 @@ def __getitem__(self, item):
         imagename = self.img_name[item]
         img_path = os.path.join(self.img_path, imagename)
 
-        if self.file_end in ['.png', '.bmp', '.tif']:
+        file_end = imagename.split('.')[-1]
+
+        if file_end in ['png']:
             npimg = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
             npimg = np.array(npimg)
 
-        elif self.file_end in ['.gz', '.nrrd', '.mha', '.nii.gz', '.nii']:
+        elif file_end in ['gz', 'nrrd', 'mha', 'nii.gz', 'nii']:
             npimg = sitk.ReadImage(img_path)
             npimg = sitk.GetArrayFromImage(npimg)
 
@@ -138,14 +116,14 @@ def __getitem__(self, item):
                                 antialias=None)
             npimg = adapt_size(npimg)
 
-        return npimg, imagename.replace('_0000', ''), ori_shape
+        return npimg, imagename, ori_shape
 
     def __len__(self):
         size = int(len(self.img_name))
        return size
 
 
-# if __name__ == "__main__":
+
 
 def test_model():
     cudnn.benchmark = False
@@ -155,37 +133,31 @@ def test_model():
     torch.manual_seed(args.seed)
     torch.cuda.manual_seed(args.seed)
 
+    data_json_file = os.path.join(os.environ['medseg_raw'], os.environ['current_dataset'], 'dataset.json')
+
+    with open(data_json_file) as f:
+        json_data = json.load(f)
+    num_classes = json_data['label_class_num']
+    in_channels = json_data['img_channel']
+    args.img_size = json_data['imgae_size']
+
+
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    if device.type == 'cuda':
-        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # bytes to GB
-        args.batch_size = int(total_memory / 10) * 2
-    else:
-        args.batch_size = 2
 
-    fold = os.environ['current_fold']
 
-    data_json_file = os.path.join(os.environ['nnUNet_raw'], os.environ['current_dataset'], 'dataset.json')
-    # split_json_path = os.path.join(os.environ['nnUNet_preprocessed'], os.environ['current_dataset'], 'splits_final.json')
-    # base_json_path = os.path.join(os.environ['nnUNet_preprocessed'], os.environ['current_dataset'])
-    output_folder_test = os.path.join(os.environ['nnUNet_results'], os.environ['MODEL_NAME'],
-                                      os.environ['current_dataset'], 'nnUNetTrainer__nnUNetPlans__2d', 'test_pred')
-    output_folder_5fold = os.path.join(os.environ['nnUNet_results'], os.environ['MODEL_NAME'],
-                                       os.environ['current_dataset'], 'nnUNetTrainer__nnUNetPlans__2d', f'fold_{fold}')
+    args.batch_size = 1
+
+
+    output_folder_test = os.path.join(os.environ['medseg_results'], os.environ['current_dataset'], os.environ['MODEL_NAME'], 'test_pred')
 
     os.makedirs(output_folder_test, exist_ok=True)
-    os.makedirs(output_folder_5fold, exist_ok=True)
+
 
-    imageTr_path = os.path.join(os.environ['nnUNet_raw'], os.environ['current_dataset'], 'imagesTr')
-    labelTr_path = os.path.join(os.environ['nnUNet_raw'], os.environ['current_dataset'], 'labelsTr')
-    imageTs_path = os.path.join(os.environ['nnUNet_raw'], os.environ['current_dataset'], 'imagesTs')
-    labelTs_path = os.path.join(os.environ['nnUNet_raw'], os.environ['current_dataset'], 'labelsTs')
 
-    with open(data_json_file) as f:
-        json_data = json.load(f)
-    num_classes = len(json_data['labels'])
-    in_channels = len(json_data['channel_names'])
+    imageTs_path = os.path.join(os.environ['medseg_raw'], os.environ['current_dataset'], 'imagesTs')
+    labelTs_path = os.path.join(os.environ['medseg_raw'], os.environ['current_dataset'], 'labelsTs')
+    weights_path = os.path.join(os.environ['medseg_results'], os.environ['current_dataset'], os.environ['MODEL_NAME'], 'checkpoint_final.pth')
 
-    weights_path = os.path.join(output_folder_5fold, 'checkpoint_final.pth')
 
     model_name = os.environ['MODEL_NAME']
     if model_name == 'unet':
@@ -211,6 +183,7 @@ def test_model():
     elif model_name == 'swinunet':
         args.cfg = './networks/swin_tiny_patch4_window7_224_lite.yaml'
         args.opts = None
+        args.img_size = 224
        swin_config = get_swin_config(args)
         model = SwinUnet(swin_config, img_size=224, num_classes=num_classes).cuda()
         url = "https://drive.google.com/uc?id=1TyMf0_uvaxyacMmVzRfqvLLAWSOE2bJR"
@@ -226,37 +199,52 @@ def test_model():
     elif model_name == 'r2unet':
         model = R2U_Net(in_ch=in_channels, out_ch=num_classes).cuda()
 
+    elif model_name == 'gtunet':
+        model = GT_U_Net(in_ch=in_channels, out_ch=num_classes).to(device)
+        args.img_size = 256
+
+    elif model_name == 'bisenet':
+        model = BiSeNet(in_ch=in_channels, out_ch=num_classes).to(device)
+
+    elif model_name == 'ddrnet':
+        model = DDRNet(in_ch=in_channels, out_ch=num_classes).to(device)
+
+    elif model_name == 'deeplabv3plus':
+        model = Deeplabv3plus_res50(in_ch=in_channels, out_ch=num_classes).to(device)
+
+    elif model_name == 'hrnet':
+        model = HighResolutionNet(in_ch=in_channels, out_ch=num_classes).to(device)
+
+    elif model_name == 'segnet':
+        model = SegNet(in_ch=in_channels, out_ch=num_classes).to(device)
+
+    elif model_name == 'fcnresnet':
+        model = FCN_ResNet(in_ch=in_channels, out_ch=num_classes).to(device)
+
+    elif model_name == 'yournet':
+        model = Your_Net(in_ch=in_channels, out_ch=num_classes).to(device)
+
+
     else:
         raise NotImplementedError(f"model_name {model_name} not supported")
 
     model.load_state_dict(torch.load(weights_path))
 
-    logging.basicConfig(filename="logging.txt", level=logging.INFO,
-                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
-    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
-    logging.info(str(args))
+    logger = setup_logger("training_logger", output_file=output_file)
+    logger.info("Process started")
+    logger.info(str(args))
     base_lr = args.base_lr
-    batch_size = args.batch_size * args.n_gpu
-    if model_name == 'swinunet' or model_name == 'transunet':
-        db_test = DynamicDataset(img_path=imageTs_path, gt_path=labelTs_path, data_end_json=data_json_file,
-                                 size=args.img_size)
-    else:
-        db_test = DynamicDataset(img_path=imageTs_path, gt_path=labelTs_path, data_end_json=data_json_file)
+
+    db_test = DynamicDataset(img_path=imageTs_path, gt_path=labelTs_path, size=args.img_size)
 
-    with open(data_json_file) as f:
-        file_end = json.load(f)['file_ending']
+
 
-    testloader = DataLoader(db_test, batch_size=1, shuffle=True, num_workers=2, pin_memory=True)
+    testloader = DataLoader(db_test, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True)
     if args.n_gpu > 1:
         model = nn.DataParallel(model)
-    model.train()
-    ce_loss = CrossEntropyLoss()
-    dice_loss = DiceLoss(num_classes)
-    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
-
-    best_performance = 0.0
-    val_dice_scores = []
-    epoch_numbers = []
+    model.eval()
+
+
     for i_batch, (img, img_name, ori_shape) in enumerate(testloader):
         image_batch = img
         image_batch = image_batch.cuda()
@@ -268,11 +256,16 @@ def test_model():
                                 align_corners=True)
         pred = outputs.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
         pred = pred.astype(np.uint8)
+        print(f"Processing {img_name[0]}")
 
-        if file_end in ['.png', '.bmp', '.tif']:
+        file_end = img_name[0].split('.')[-1]
+        if file_end in ['png', 'bmp', 'tif']:
             pred_img = Image.fromarray(pred)
             pred_img.save(os.path.join(output_folder_test, img_name[0]))
 
-        elif file_end in ['.gz', '.nrrd', '.mha', '.nii.gz', '.nii']:
+        elif file_end in ['gz', 'nrrd', 'mha', 'nii.gz', 'nii']:
             pred_img = sitk.GetImageFromArray(pred)
-            sitk.WriteImage(pred_img, os.path.join(output_folder_test, img_name[0]))
+            sitk.WriteImage(pred_img, os.path.join(output_folder_test, img_name[0]))
+
+if __name__ == "__main__":
+    test_model()
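The rewritten test.py drops the nnUNet-style environment variables and fold handling: dataset metadata now comes from the dataset.json written by data_process.py, paths come from `medseg_raw` / `medseg_results` plus `current_dataset` and `MODEL_NAME`, and the new `__main__` guard makes the script directly runnable. A hedged sketch of a standalone invocation, where the directory and dataset names are placeholders and only the environment variable names are taken from the diff:

```python
# Sketch of a standalone run; paths and the dataset name are placeholders.
import os
import subprocess

os.environ["medseg_raw"] = "/data/medseg_raw"          # holds <dataset>/imagesTs, labelsTs, dataset.json
os.environ["medseg_results"] = "/data/medseg_results"  # holds <dataset>/<model>/checkpoint_final.pth
os.environ["current_dataset"] = "Dataset001_Demo"      # placeholder dataset folder name
os.environ["MODEL_NAME"] = "unet"                      # any branch handled in test.py, e.g. unet, swinunet, gtunet

# Predictions are written to <medseg_results>/<dataset>/<model>/test_pred
subprocess.run(["python", "test.py"], check=True)
```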
