Commit 5b376a6

training model
1 parent 4be4aba commit 5b376a6

File tree: 1 file changed, +276 −0 lines

main.py

import os
import shutil
import sys
import argparse
import time
import itertools

import numpy as np
import torch
import torch.nn as nn
import warnings
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
from torch.backends import cudnn
from torch.nn import DataParallel
import torchvision.transforms as transforms
import torchvision.models as models
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

sys.path.append('./')
from utils.util import set_prefix, write, add_prefix
from utils.FocalLoss import FocalLoss

plt.switch_backend('agg')

parser = argparse.ArgumentParser(description='Training on Diabetic Retinopathy Dataset')
parser.add_argument('--batch_size', '-b', default=90, type=int, help='batch size')
parser.add_argument('--epochs', '-e', default=90, type=int, help='training epochs')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--cuda', default=torch.cuda.is_available(), type=bool, help='use gpu or not')
parser.add_argument('--step_size', default=30, type=int, help='learning rate decay interval')
parser.add_argument('--gamma', default=0.1, type=float, help='learning rate decay factor')
parser.add_argument('--interval_freq', '-i', default=12, type=int, help='log printing frequency')
parser.add_argument('--data', '-d', default='./data/data_augu', help='path to dataset')
parser.add_argument('--prefix', '-p', default='classifier', type=str, help='folder prefix')
parser.add_argument('--best_model_path', default='model_best.pth.tar', help='best model saved path')
parser.add_argument('--is_focal_loss', '-f', action='store_false',
                    help='use focal loss or common loss (i.e. cross entropy loss) (default: true)')

best_acc = 0.0
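
# Example invocation (all flags exist in the parser above; the values shown
# are illustrative):
#   python main.py --batch_size 90 --epochs 90 --lr 1e-3 --data ./data/data_augu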


def main():
    global args, best_acc
    args = parser.parse_args()
    # save the source script alongside the results
    set_prefix(args.prefix, __file__)
    model = models.densenet121(pretrained=False, num_classes=2)
    if args.cuda:
        model = DataParallel(model).cuda()
    else:
        warnings.warn('there is no gpu')

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # let cudnn pick the fastest convolution algorithms
    cudnn.benchmark = True

    train_loader, val_loader = load_dataset()
    # class_names = ['LESION', 'NORMAL']
    class_names = train_loader.dataset.classes
    print(class_names)
    if args.is_focal_loss:
        print('try focal loss!!')
        criterion = FocalLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.cuda:
        criterion = criterion.cuda()

    # decay the learning rate every `step_size` epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    since = time.time()
    print('-' * 10)
    for epoch in range(args.epochs):
        train(train_loader, model, optimizer, criterion, epoch)
        exp_lr_scheduler.step()
        cur_accuracy = validate(model, val_loader, criterion)
        is_best = cur_accuracy > best_acc
        best_acc = max(cur_accuracy, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': 'densenet121',
            'state_dict': model.state_dict(),
            'best_accuracy': best_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # compute validation metrics such as the confusion matrix
    compute_validate_meter(model, add_prefix(args.prefix, args.best_model_path), val_loader)
    # save the run's parameter settings to json
    write(vars(args), add_prefix(args.prefix, 'paras.txt'))


def compute_validate_meter(model, best_model_path, val_loader):
    checkpoint = torch.load(best_model_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    best_acc = checkpoint['best_accuracy']
    print('best accuracy={:.4f}'.format(best_acc))
    pred_y = list()
    test_y = list()
    probas_y = list()
    with torch.no_grad():
        for data, target in val_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            probas_y.extend(output.data.cpu().numpy().tolist())
            pred_y.extend(output.data.cpu().max(1, keepdim=True)[1].numpy().flatten().tolist())
            test_y.extend(target.data.cpu().numpy().flatten().tolist())

    # sklearn's confusion_matrix expects (y_true, y_pred)
    confusion = confusion_matrix(test_y, pred_y)
    plot_confusion_matrix(confusion,
                          classes=val_loader.dataset.classes,
                          title='Confusion matrix')
    plt_roc(test_y, probas_y)


def plt_roc(test_y, probas_y, plot_micro=False, plot_macro=False):
    assert isinstance(test_y, list) and isinstance(probas_y, list), 'the type of input must be list'
    skplt.metrics.plot_roc(test_y, probas_y, plot_micro=plot_micro, plot_macro=plot_macro)
    plt.savefig(add_prefix(args.prefix, 'roc_auc_curve.png'))
    plt.close()


def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    reference:
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig(add_prefix(args.prefix, 'confusion_matrix.png'))
    plt.close()


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # save training state after each epoch
    torch.save(state, add_prefix(args.prefix, filename))
    if is_best:
        shutil.copyfile(add_prefix(args.prefix, filename),
                        add_prefix(args.prefix, args.best_model_path))
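
# Resuming from a checkpoint is not wired into this script; a hypothetical
# restore would mirror the fields saved above:
#   checkpoint = torch.load(add_prefix(args.prefix, 'checkpoint.pth.tar'))
#   model.load_state_dict(checkpoint['state_dict'])
#   optimizer.load_state_dict(checkpoint['optimizer'])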


def load_dataset():
    if args.data == './data/data_augu':
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        mean = [0.5186, 0.5186, 0.5186]
        std = [0.1968, 0.1968, 0.1968]
        normalize = transforms.Normalize(mean, std)
        train_transforms = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        val_transforms = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = ImageFolder(traindir, train_transforms)
        val_dataset = ImageFolder(valdir, val_transforms)
        print('loaded data-augmentation dataset successfully!!!')
    else:
        raise ValueError("parameter 'data' (the dataset path) must be one of "
                         "['./data/data_augu']")

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=args.cuda)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=args.cuda)
    return train_loader, val_loader


def train(train_loader, model, optimizer, criterion, epoch):
    model.train(True)
    print('Epoch {}/{}'.format(epoch + 1, args.epochs))
    print('-' * 10)
    running_loss = 0.0
    running_corrects = 0

    # Iterate over data.
    for idx, (inputs, labels) in enumerate(train_loader):
        if args.cuda:
            inputs, labels = inputs.cuda(), labels.cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward
        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if idx % args.interval_freq == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch + 1, idx * len(inputs), len(train_loader.dataset),
                100. * idx / len(train_loader), loss.item()))

        # statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data).item()

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects / len(train_loader.dataset)

    print('Training Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))


def validate(model, val_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += criterion(output, target).item()
            # get the index of the max log-probability
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).long().cpu().sum().item()

    test_loss /= len(val_loader.dataset)
    test_acc = 100. * correct / len(val_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        test_loss, correct, len(val_loader.dataset), test_acc))
    return test_acc


if __name__ == '__main__':
    main()
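
Note: FocalLoss is imported from utils/FocalLoss.py, which is not part of this
commit. For reference, here is a minimal sketch of a focal loss consistent with
how the criterion is called above (raw logits in, class indices as targets);
gamma=2 and the absence of per-class alpha weighting are assumptions, and the
repo's actual implementation may differ:

import torch.nn as nn
import torch.nn.functional as F

class FocalLossSketch(nn.Module):
    # FL(p_t) = -(1 - p_t)**gamma * log(p_t); gamma=2 is assumed here
    def __init__(self, gamma=2.0):
        super(FocalLossSketch, self).__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # log-probability of the true class for each sample
        log_pt = F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        # (1 - pt)**gamma down-weights well-classified (easy) examples
        return (-(1.0 - pt) ** self.gamma * log_pt).mean()

With gamma=0 this reduces to standard cross entropy, matching the
nn.CrossEntropyLoss branch above.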
