# utils.py (forked from salwaalkhatib/pytorch-intro)
from collections.abc import Mapping

import torch
from torchvision import models
from tqdm import tqdm
class _LazyModelRegistry(Mapping):
    """Read-only ``name -> model`` mapping that builds each model on first lookup.

    Building a pretrained torchvision model fetches its weights (unless they are
    already cached locally) and allocates all of its parameters. Doing that for
    every entry at import time made merely importing this module slow,
    memory-hungry and network-dependent. Here each factory runs only the first
    time its key is looked up. The result is cached, so repeated lookups return
    the same instance, just as the previous eagerly-built dict did.
    """

    def __init__(self, factories):
        # Copy so later mutation of the caller's dict cannot change the registry.
        self._factories = dict(factories)
        self._cache = {}

    def __getitem__(self, name):
        if name not in self._cache:
            # An unknown name raises KeyError from the factory lookup,
            # the same as indexing a plain dict.
            self._cache[name] = self._factories[name]()
        return self._cache[name]

    def __contains__(self, name):
        # Mapping's default __contains__ calls __getitem__, which would build
        # (and possibly download) the model just to answer a membership test.
        return name in self._factories

    def __iter__(self):
        return iter(self._factories)

    def __len__(self):
        return len(self._factories)


# Pretrained backbones keyed by short name. Each one is only constructed when
# first looked up, e.g. MODELS["resnet"].
MODELS = _LazyModelRegistry({
    "vgg16": lambda: models.vgg16(weights='VGG16_Weights.DEFAULT'),
    "vgg19": lambda: models.vgg19(weights='VGG19_Weights.DEFAULT'),
    "resnet": lambda: models.resnet50(weights='ResNet50_Weights.DEFAULT'),
    "mobilenet": lambda: models.mobilenet_v2(weights='MobileNet_V2_Weights.DEFAULT'),
})
def save_checkpoint(model, optimizer, epoch, path, best=False):
    """Write the current training state to ``path`` and log it to stdout.

    The saved object is a dict with keys ``'epoch'``, ``'model_state_dict'`` and
    ``'optimizer_state_dict'``.

    Args:
        model: module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``. Saving it is what makes resuming
            possible: a restored run continues with the optimizer's internal
            state (for example momentum buffers or Adam moment estimates)
            instead of restarting that state from scratch.
        epoch: epoch number stored with the weights and echoed in the log line.
        path: destination handed to ``torch.save``.
        best: only affects the log message, marking this as the best checkpoint.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(state, path)
    label = "best checkpoint" if best else "checkpoint"
    print(f"[INFO] {label} saved at epoch {epoch}")
def load_checkpoint(model, path, map_location="cpu"):
    """Restore model weights from a checkpoint written by ``save_checkpoint``.

    Only the ``'model_state_dict'`` entry is applied; any other entries in the
    file are ignored here.

    Args:
        model: module that receives the saved weights (modified in place).
        path: checkpoint file to read.
        map_location: forwarded to ``torch.load``. It defaults to ``"cpu"`` so a
            checkpoint saved during a GPU run can still be opened on a machine
            without that device. ``load_state_dict`` then copies the tensors
            into the model's existing parameters, wherever those live. Pass
            ``None`` to restore tensors onto the devices they were saved from.

    Returns:
        The same ``model`` object, for convenient chaining.

    Note:
        ``torch.load`` unpickles the file, so only load checkpoints from
        trusted sources.
    """
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt['model_state_dict'])
    print("[INFO] checkpoint loaded")
    return model
def train(model, optimizer, criterion, train_loader, DEVICE):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module being optimized; it is switched to training mode here.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        DEVICE: device every batch is moved to before the forward pass
            (presumably the device ``model`` lives on; confirm at the call site).

    Returns:
        The average of ``loss.item()`` over the batches seen in this epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()  # Training-mode behaviour for layers such as dropout / batch norm
    running_loss = 0.0
    # Count batches while iterating so loaders without __len__ also work.
    num_batches = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)  # move data to device
        optimizer.zero_grad()  # Gradients would otherwise accumulate across batches
        # Forward pass: model predictions, then the loss against the targets
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Backpropagation and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return running_loss / num_batches
def validate(model, val_loader, DEVICE):
    """Return the classification accuracy of ``model`` on ``val_loader``, in percent.

    The model is put into evaluation mode and left there. ``train`` switches it
    back to training mode at the start of each epoch.

    Args:
        model: classifier to evaluate.
        val_loader: iterable yielding ``(inputs, labels)`` batches.
        DEVICE: device every batch is moved to before the forward pass.

    Returns:
        ``correct / total * 100`` over all samples seen.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    # Without this, dropout stays active and batch norm uses per-batch
    # statistics, which skews the measured accuracy.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # No autograd graph is needed for evaluation
        for inputs, labels in tqdm(val_loader):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            # Predicted class = index of the highest score along dim 1
            # (assumes outputs are (batch, num_classes) scores; confirm).
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("val_loader yielded no samples")
    return correct / total * 100