training_validation_utils.py
import os
import sys
# Add this file's directory to sys.path so that the local imports below work
# when this file is called from another directory
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import glob
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import trange

from model_utils import *
from loss_utils import *
from augmentation_utils import DataAugmentation
from segment_metrics import IOU_eval
from augmentations import rotate

class Training_Validation():
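    """Training/validation loop helper for a U-Net style segmentation model.

    The actual forward/backward pass is delegated to the ``dtv`` object
    passed to ``train_valid``, which is expected to provide ``train_step``
    and ``get_val_metrics``.
    """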
    def __init__(self, power=400, swap_coeffs=False, is_train_data_aug=False, is_apply_color_jitter=False):
        super(Training_Validation, self).__init__()
        self.power = power
        self.swap_coeffs = swap_coeffs
        self.is_train_data_aug = is_train_data_aug
        self.is_apply_color_jitter = is_apply_color_jitter
    def train_valid(self, unet, train_data, optimizer, criterion, batch_size,
                    dtv, use_gpu, epochs, scheduler=None, metrics=None, validation_data=None,
                    epoch_lapse=1, progress_bar=True, checkpoint_path=None, early_stop: int = None,
                    angle: int = 0, base_angles=[0], rotate_crop=False):
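        """Run the training loop with periodic validation and top-n checkpointing.

        ``train_data`` and each entry of ``validation_data`` are (x, y) pairs.
        ``metrics`` is a list of metric names understood by ``dtv``; it must
        include 'IoU' if ``checkpoint_path`` is set, since checkpoint ranking
        uses ``val_IoU``. Returns (train_metrics, val_metrics, stopped_early).
        """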
        x_train, y_train = train_data
        if not torch.is_tensor(x_train):
            # torch.Tensor() does not accept a dtype keyword; use torch.as_tensor
            x_train = torch.as_tensor(x_train, dtype=torch.float32)
            y_train = torch.as_tensor(y_train, dtype=torch.float32)
        epoch_iter = np.ceil(x_train.shape[0] / batch_size).astype(int)
        if use_gpu:
            x_train = x_train.cuda()
            y_train = y_train.cuda()
        if progress_bar:
            t = trange(epochs, leave=True)
        else:
            t = range(epochs)
        if checkpoint_path and (not os.path.isdir(checkpoint_path)):
            os.makedirs(checkpoint_path)
        if metrics is None:
            metrics = []
        train_metrics = {'epoch': [], 'loss': []}
        val_metrics = {'epoch': [], 'val_loss': []}
        for metric in metrics:
            train_metrics[metric] = []
            val_metrics['val_' + metric] = []
        if validation_data is not None and len(validation_data) > 1:
            # Extra validation sets get their own suffixed metric keys
            for i in range(2, len(validation_data) + 1):
                for metric in metrics:
                    val_metrics['val_' + metric + '_' + str(i)] = []
                val_metrics['val_loss_' + str(i)] = []
        if scheduler is not None:
            train_metrics['lr'] = []
        # Initialize both current checkpoints as epoch 1 (index 0 in the metrics lists).
        # The length of this array determines how many top-n checkpoints are kept.
        ckpt_list = np.array([0, 0])
        for epoch in t:  # epoch loop
            ts = time.time()
            metrics_sum = dict((metric, 0) for metric in metrics)
            metrics_sum['loss'] = 0
            for i in range(epoch_iter):  # train step loop
                batch_train_x = x_train[i * batch_size : (i + 1) * batch_size]
                batch_train_y = y_train[i * batch_size : (i + 1) * batch_size]
                # Rotation augmentation is currently disabled:
                # batch_train_x, batch_train_y = rotate(batch_train_x, batch_train_y, angle, base_angles=base_angles, crop=rotate_crop)
                batch_metrics = dtv.train_step(batch_train_x, batch_train_y, optimizer, criterion, unet, metrics=metrics)
                for key in batch_metrics.keys():
                    metrics_sum[key] += batch_metrics[key]
            te = time.time() - ts
            # Reshuffle the training set after each epoch so batches differ between epochs
            shuffle = torch.randperm(x_train.shape[0])
            x_train = x_train[shuffle]
            y_train = y_train[shuffle]
            epoch_metrics = {}
            train_metrics['epoch'].append(epoch + 1)
            if scheduler is not None:
                train_metrics['lr'].append(scheduler.get_last_lr())
                scheduler.step()
            for key in metrics_sum.keys():
                epoch_metrics[key] = metrics_sum[key] / epoch_iter
                train_metrics[key].append(epoch_metrics[key])
            if (epoch + 1) % epoch_lapse == 0:  # validation
                string = 'Epoch %.0f/%.0f -- %.0fs, %.3fs/step - ' % (epoch + 1, epochs, te, te / epoch_iter)
                string = string + ', '.join([' ' + metric + ': %.4f' % train_metrics[metric][-1] for metric in metrics_sum.keys()])
                if validation_data is not None:
                    val_metrics['epoch'].append(epoch + 1)
                    for i, val_set in enumerate(validation_data):
                        x_val, y_val = val_set
                        if not torch.is_tensor(x_val):
                            x_val = torch.as_tensor(x_val, dtype=torch.float32)
                            y_val = torch.as_tensor(y_val, dtype=torch.float32)
                        val = dtv.get_val_metrics(x_val, y_val, criterion, unet, metrics)
                        if i == 0:
                            for key in val.keys():
                                val_metrics[key].append(val[key])
                            string = string + ' - ' + ', '.join([' val_' + metric + ': %.4f' % val_metrics['val_' + metric][-1] for metric in metrics_sum.keys()])
                        else:
                            for key in val.keys():
                                val_metrics[key + '_' + str(i + 1)].append(val[key])
                            # Metrics from the additional validation sets are not added to the
                            # output string because it would become too long to read
                    if checkpoint_path:
                        # With multiple validation sets, the first one drives checkpointing.
                        # NOTE: the filename arithmetic below appears to assume epoch_lapse == 1,
                        # so that validation index i corresponds to epoch i + 1.
                        v = np.array(val_metrics['val_IoU'])
                        ind = np.argmin(v[ckpt_list])  # index of the worst saved checkpoint within ckpt_list
                        if v[ckpt_list[ind]] < val_metrics['val_IoU'][-1]:
                            # Remove the worst checkpoint's files
                            for file in glob.glob(os.path.join(checkpoint_path, 'ckpt_epoch_' + str(ckpt_list[ind] + 1) + '*')):
                                os.remove(file)
                            # Save the new checkpoint
                            torch.save(unet, os.path.join(checkpoint_path, 'ckpt_epoch_' + str(epoch + 1) + '.pt'))
                            string = string + ' - ckpt saved'
                            # Replace the index of the worst ckpt with the index of the new ckpt (epoch - 1)
                            ckpt_list[ind] = len(v) - 1
                print(string, flush=True)
                if early_stop is not None:
                    # Stop once val loss has stayed more than 0.2 above its minimum
                    # for `early_stop` consecutive validation epochs
                    if all(i > (np.min(val_metrics['val_loss']) + 0.2) for i in val_metrics['val_loss'][-early_stop:]):
                        print('Stopping early because val loss has been 0.2 greater than lowest val loss for %i consecutive epochs' % early_stop)
                        return train_metrics, val_metrics, True
        return train_metrics, val_metrics, False
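

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only, not part of the original module).
# `MyUNet`, `my_dtv`, and the tensors below are hypothetical placeholders;
# `dtv` must expose train_step(...) and get_val_metrics(...) as called above,
# and `metrics` must include 'IoU' whenever `checkpoint_path` is given.
# ---------------------------------------------------------------------------
# unet = MyUNet()
# optimizer = optim.Adam(unet.parameters(), lr=1e-3)
# criterion = nn.BCEWithLogitsLoss()
# tv = Training_Validation()
# train_hist, val_hist, stopped_early = tv.train_valid(
#     unet, (x_train, y_train), optimizer, criterion, batch_size=8,
#     dtv=my_dtv, use_gpu=torch.cuda.is_available(), epochs=50,
#     metrics=['IoU'], validation_data=[(x_val, y_val)],
#     checkpoint_path='checkpoints', early_stop=10)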