Skip to content

Commit 9f6450e

Browse files
committed
bug fix. Train and test runs
1 parent 3d2c861 commit 9f6450e

File tree

3 files changed

+14
-66
lines changed

3 files changed

+14
-66
lines changed

logHelper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ def setup_logger(logger_name, output_file=None):
1313
formatter = logging.Formatter('[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
1414

1515
# Always log to the default "logging.txt" file
16-
default_fh = logging.FileHandler(global_logging_txt)
16+
# file_handler = logging.FileHandler(filename='your_log_file.log', encoding='utf-8')
17+
default_fh = logging.FileHandler(filename=global_logging_txt, encoding='utf-8')
1718
default_fh.setLevel(logging.INFO)
1819
default_fh.setFormatter(formatter)
1920
logger.addHandler(default_fh)

train.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ def __len__(self):
125125
return size
126126

127127

128+
def format_progress_message(iteration, total):
129+
130+
progress_percentage = (iteration / total) * 100
131+
progress_bar_length = int(50 * iteration // total)
132+
progress_bar = '#' * progress_bar_length + '-' * (50 - progress_bar_length)
133+
134+
return f"{progress_percentage:3.0f}%|{progress_bar}| {iteration}/{total}"
135+
136+
128137
# if __name__ == "__main__":
129138
def train(batch_size=4, max_epochs=200, base_lr=0.01, seed=1234, n_gpu=1, img_size=224, model_name='swinunet'):
130139
if os.environ.get('current_fold') is None:
@@ -282,7 +291,8 @@ def train(batch_size=4, max_epochs=200, base_lr=0.01, seed=1234, n_gpu=1, img_si
282291
param_group['lr'] = lr_
283292

284293
iter_num = iter_num + 1
285-
294+
progress_message = format_progress_message(epoch_num + 1, max_epoch)
295+
logger.info(progress_message)
286296
# logger.info('iteration %d : loss : %f' % (iter_num, loss.item()))
287297

288298
save_interval = 5

web.py

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def run_test():
575575
# TODO convert it from command line to python code
576576
# threading.Thread(target=run_command_async, args=(complete_command,)).start()
577577

578-
test_process = Process(target=run_command_async, args="test")
578+
test_process = Process(target=run_command_async, args=("test", 4, 2, 0.01))
579579
test_process.start()
580580
test_process.join()
581581
print("Test process started")
@@ -598,69 +598,6 @@ def run_test():
598598
data_json = json.load(f)
599599
label_num = len(data_json['labels'])
600600

601-
# if os.environ['MODEL_NAME'] == 'nnunet3d':
602-
# if file_ext in ['.gz', '.nrrd', '.mha', '.nii']:
603-
# img = sitk.ReadImage(img_path)
604-
# img = sitk.GetArrayFromImage(img)
605-
#
606-
# pred = sitk.ReadImage(prediction_path)
607-
# pred = sitk.GetArrayFromImage(pred)
608-
# ground_truth = sitk.ReadImage(ground_truth_path)
609-
# ground_truth = sitk.GetArrayFromImage(ground_truth)
610-
# else:
611-
# return jsonify({'status': 'Please use png, bmp, tif, nii.gz, nrrd or mha format.'})
612-
#
613-
# half_layer = int(np.argmax(ground_truth.sum(axis=(1, 2))))
614-
# img = img[half_layer]
615-
# img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
616-
#
617-
# mask_result = np.zeros_like(img)
618-
# img_with_GT = img.copy()
619-
#
620-
# each_metric = []
621-
# for i in range(1, label_num):
622-
# each_metric.append(calculate_metric_percase(pred == i, ground_truth == i))
623-
#
624-
# mask = np.where(pred[half_layer] == i, 1, 0).astype(np.uint8)
625-
# colored_mask = np.zeros_like(img)
626-
# colored_mask[mask == 1] = colors[i - 1]
627-
# img = cv2.addWeighted(img, 1, colored_mask, 0.5, 0)
628-
#
629-
# gt_mask = np.where(ground_truth[half_layer] == i, 1, 0).astype(np.uint8)
630-
# colored_mask_gt = np.zeros_like(img_with_GT)
631-
# colored_mask_gt[gt_mask == 1] = colors[i - 1]
632-
# img_with_GT = cv2.addWeighted(img_with_GT, 1, colored_mask_gt, 0.5, 0)
633-
#
634-
# mask_result[mask == 1] = colors[i - 1]
635-
#
636-
# metric_list.append(each_metric)
637-
#
638-
# img_with_mask_save_path = os.path.join(os.environ['nnUNet_results'], os.environ['MODEL_NAME'],
639-
# os.environ['current_dataset'], nnUNetPlans,
640-
# 'visualization_result', test_img_name + '.png')
641-
# os.makedirs(
642-
# os.path.join(os.environ['nnUNet_results'], os.environ['MODEL_NAME'], os.environ['current_dataset'],
643-
# nnUNetPlans, 'visualization_result'), exist_ok=True)
644-
# cv2.imwrite(img_with_mask_save_path, img)
645-
#
646-
# img_with_GT_save_path = os.path.join(os.environ['nnUNet_results'], os.environ['MODEL_NAME'],
647-
# os.environ['current_dataset'], nnUNetPlans, 'GT_result',
648-
# test_img_name + '.png')
649-
# os.makedirs(
650-
# os.path.join(os.environ['nnUNet_results'], os.environ['MODEL_NAME'], os.environ['current_dataset'],
651-
# nnUNetPlans, 'GT_result'), exist_ok=True)
652-
# cv2.imwrite(img_with_GT_save_path, img_with_GT)
653-
#
654-
# mask_save_path = os.path.join(os.environ['nnUNet_results'], os.environ['MODEL_NAME'],
655-
# os.environ['current_dataset'], nnUNetPlans, 'mask_result',
656-
# test_img_name + '.png')
657-
# os.makedirs(
658-
# os.path.join(os.environ['nnUNet_results'], os.environ['MODEL_NAME'], os.environ['current_dataset'],
659-
# nnUNetPlans, 'mask_result'), exist_ok=True)
660-
# cv2.imwrite(mask_save_path, mask_result)
661-
662-
# else:
663-
664601
if file_ext in ['.png', '.bmp', '.tif']:
665602

666603
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)

0 commit comments

Comments
 (0)