-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_teacher.py
101 lines (75 loc) · 3.59 KB
/
train_teacher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from torch.nn.modules.loss import CrossEntropyLoss
from models.teacher_mnist import TeacherNetMnist
from dataloader import create_dataloaders_mnist
from utils import count_parameters, create_parser_train_teacher
from TorchUtils.training.EarlyStopping import EarlyStopping
from TorchUtils.training.StatsTracker import StatsTracker
import torch
from torch.nn.functional import softmax
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm
import os
from visualization.plot_train_graph import plot_train_graph
def train_model(save, save_dir, net, lr, epochs, train_loader, val_loader, device, batch_size=32):
    """Train `net` with Adam, LR-on-plateau scheduling and early stopping.

    Runs up to `epochs` epochs; each epoch does a full training pass
    followed by a no-grad validation pass that tracks loss and accuracy.

    Parameters
    ----------
    save : bool -- if True, save the tracker's best model under `save_dir`.
    save_dir : str -- checkpoint directory, created if it does not exist.
    net : torch.nn.Module -- network to optimize (moved to `device` by caller).
    lr : float -- initial Adam learning rate.
    epochs : int -- maximum number of epochs.
    train_loader, val_loader : iterables yielding (inputs, labels) batches.
    device : torch.device -- device every batch is moved to.
    batch_size : int -- only used to normalize the accumulated train loss.

    Returns
    -------
    (train_hist, val_hist) -- per-epoch loss histories from the stats tracker.
    """
    optimizer = Adam(params=net.parameters(), lr=lr)
    statsTracker = StatsTracker()
    earlyStopping = EarlyStopping(patience=8, delta=0.0)
    scheduler = ReduceLROnPlateau(
        optimizer, 'min', patience=5, eps=0.0003, verbose=True)
    # Build the criterion once; the original constructed a new
    # CrossEntropyLoss object on every batch.
    criterion = CrossEntropyLoss(reduction="mean")
    # Pre-initialized so the save branch below cannot raise NameError
    # when epochs < 1 (the original referenced val_loss_epoch after the loop).
    val_loss_epoch = float("inf")

    for epoch in range(1, epochs + 1):
        statsTracker.reset()

        # --- training pass ---
        net.train()
        for x, labels in tqdm.tqdm(train_loader):
            x, labels = x.to(device=device), labels.to(device=device)
            outputs = net(x)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            statsTracker.update_curr_losses(loss.item(), None)

        # --- validation pass ---
        correct = 0
        total = 0
        with torch.no_grad():
            net.eval()
            for val_x, val_labels in tqdm.tqdm(val_loader):
                val_x, val_labels = val_x.to(
                    device=device), val_labels.to(device=device)
                val_outputs = net(val_x)
                val_loss = criterion(val_outputs, val_labels)
                statsTracker.update_curr_losses(None, val_loss.item())
                # argmax of the logits equals argmax of softmax(logits);
                # the softmax call in the original was redundant.
                matching = torch.eq(
                    torch.argmax(val_outputs, dim=1), val_labels)
                correct += torch.sum(matching, dim=0).item()
                total += val_x.shape[0]

        # NOTE(review): both accumulators appear to hold sums of per-batch
        # *mean* losses, yet the train sum is divided by
        # batch_size * num_batches while the val sum is divided by the sample
        # count -- the normalizations look inconsistent. Preserved as-is to
        # keep reported numbers comparable with earlier runs; confirm intent
        # against StatsTracker.update_curr_losses before changing.
        train_loss_epoch = statsTracker.train_loss_curr / \
            (batch_size * len(train_loader))
        val_loss_epoch = statsTracker.val_loss_curr / total
        val_accuracy = correct / total

        statsTracker.update_histories(train_loss_epoch, None)
        statsTracker.update_histories(None, val_loss_epoch, net)

        print("correct: " + str(correct) + " out of: " + str(total))
        print('Teacher_network: Epoch {}, Train Loss {}, Val Loss {}, Val Accuracy {}'.format(
            epoch, round(train_loss_epoch, 5), round(val_loss_epoch, 5), round(val_accuracy, 5)))

        # Scheduler and early stopping both key off the validation loss.
        scheduler.step(val_loss_epoch)
        earlyStopping(val_loss_epoch)
        if earlyStopping.stop:
            break

    if save:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(statsTracker.best_model, os.path.join(
            save_dir, 'Teacher_network_val_loss{}'.format(round(val_loss_epoch, 5))))

    return statsTracker.train_hist, statsTracker.val_hist
if __name__ == "__main__":
    # Parse CLI options (save flag, checkpoint dir, learning rate, epochs).
    args = create_parser_train_teacher().parse_args()

    # Train on the GPU when one is available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device {device}.")

    # Note: these are DataLoaders, despite the factory's name.
    train_loader, val_loader = create_dataloaders_mnist()
    teacher = TeacherNetMnist().to(device=device)

    train_hist, val_hist = train_model(args.save, args.save_dir, teacher,
                                       args.lr, args.epochs, train_loader,
                                       val_loader, device)
    plot_train_graph(train_hist, val_hist, count_parameters(teacher))