runmanager.py
# Run Manager to manage
# 1. Every epoch
# 2. Experiments with runs - to be coded
import os
import time

import torch


class RunManager:
    def __init__(self, network, loader, summaryWriter):
        self.epoch_id = -1
        self.epoch_loss = 0
        self.best_epoch_loss = 1e20
        self.best_epoch_id = -1
        self.epoch_start_time = None
        self.network = network
        self.loader = loader
        self.tb = summaryWriter
        #self.best_network = None

    def begin_epoch(self):
        self.epoch_start_time = time.time()
        self.epoch_id += 1
        self.epoch_loss = 0

    def end_epoch(self):
        epoch_duration = time.time() - self.epoch_start_time
        # Average the accumulated loss over the number of batches;
        # len(self.loader) also counts a final partial batch correctly.
        self.epoch_loss = self.epoch_loss / len(self.loader)
        print(f"Epoch: {self.epoch_id} Loss: {self.epoch_loss} Duration: {epoch_duration:.1f}s")
        # save the model with the best loss so far
        if self.epoch_loss < self.best_epoch_loss:
            self.best_epoch_loss = self.epoch_loss
            self.best_epoch_id = self.epoch_id
            #self.best_network = self.network.state_dict()
            os.makedirs('saved_models', exist_ok=True)
            torch.save(self.network.state_dict(), 'saved_models/autoencoder_%d.pth' % self.best_epoch_id)
        # Tensorboard logging
        # add graph for loss
        self.tb.add_scalar('Training Loss', self.epoch_loss, self.epoch_id)
        # add graph for accuracy
        # add histograms for weights, biases, and their gradients
        # (batch-norm and spatial-transformer parameters are skipped)
        for name, param in self.network.named_parameters():
            if 'bn' not in name and 'stn' not in name:
                self.tb.add_histogram(name, param, self.epoch_id)
                if param.grad is not None:
                    self.tb.add_histogram(name + "_grad", param.grad, self.epoch_id)

    def track_loss(self, loss):
        # Accumulate the scalar batch loss for the current epoch.
        self.epoch_loss += loss.item()

    def begin_batch(self):
        pass

    def end_batch(self):
        pass
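

# --- Usage sketch (illustrative, not part of the original file) ---
# Shows how RunManager is meant to be driven from a training loop:
# begin_epoch / track_loss / end_epoch once per epoch, with the batch
# hooks called around each step. The toy autoencoder, random data, and
# hyperparameters below are assumptions, not taken from the repo.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.tensorboard import SummaryWriter

    # Toy autoencoder and random data, purely to exercise the manager.
    network = nn.Sequential(nn.Linear(16, 4), nn.ReLU(), nn.Linear(4, 16))
    data = torch.randn(256, 16)
    loader = DataLoader(TensorDataset(data), batch_size=32, shuffle=True)

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    manager = RunManager(network, loader, SummaryWriter())

    for _ in range(3):  # epochs
        manager.begin_epoch()
        for (batch,) in loader:
            manager.begin_batch()
            optimizer.zero_grad()
            loss = criterion(network(batch), batch)
            loss.backward()
            optimizer.step()
            manager.track_loss(loss)
            manager.end_batch()
        manager.end_epoch()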