
Commit 79d4f29

Update main.py to be compatible with the new model implementation and PyTorch version

The training script (main.py) is updated to work with the latest version of PyTorch and with the new implementation in simplenet.py. Instead of the previous single `simplenet()` constructor, all SimpleNet variants used in the paper are now available as distinct model configurations and can be used directly.
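For example, a variant can be built through the models registry exactly as main.py now does. A minimal sketch, assuming the repository's models package is importable; other variant names defined in simplenet.py follow the same pattern:

import models

# Instantiate a SimpleNet variant for CIFAR-10 via the registry,
# matching what main.py does for --arch simplenet_cifar_5m.
net = models.__dict__['simplenet_cifar_5m'](num_classes=10)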
1 parent 9c18e1b commit 79d4f29

1 file changed

Cifar/main.py

Lines changed: 45 additions & 52 deletions
@@ -1,5 +1,5 @@
 # original coder : https://github.com/D-X-Y/ResNeXt-DenseNet
-# added simpnet model
+# added simplenet model
 from __future__ import division
 
 import os, sys, pdb, shutil, time, random, datetime
@@ -11,7 +11,7 @@
 import torchvision.transforms as transforms
 from utils import AverageMeter, RecorderMeter, time_string, convert_secs2time
 import models
-from tensorboardX import SummaryWriter
+from torch.utils.tensorboard import SummaryWriter
 
 model_names = sorted(name for name in models.__dict__
   if name.islower() and not name.startswith("__")
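The logging change above swaps the external tensorboardX package for the SummaryWriter that ships with PyTorch (torch.utils.tensorboard, available since PyTorch 1.1 when the tensorboard package is installed). A minimal standalone sketch of that writer; the tag and value are illustrative only, not from this script:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()                              # writes event files to ./runs/<timestamp> by default
writer.add_scalar('train/loss', 0.7, global_step=0)   # illustrative tag/value for the sketch
writer.close()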
@@ -21,7 +21,7 @@
 parser = argparse.ArgumentParser(description='Trains ResNeXt on CIFAR or ImageNet', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 parser.add_argument('data_path', type=str, help='Path to dataset')
 parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'imagenet', 'svhn', 'stl10'], help='Choose between Cifar10/100 and ImageNet.')
-parser.add_argument('--arch', metavar='ARCH', default='resnet18', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext29_8_64)')
+parser.add_argument('--arch', metavar='ARCH', default='simplenet_cifar_5m', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext29_8_64)')
 # Optimization options
 parser.add_argument('--epochs', type=int, default=700, help='Number of epochs to train.')
 parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
@@ -93,7 +93,6 @@ def main():
 
     writer = SummaryWriter()
 
-
     # # Data transforms
     # mean = [0.5071, 0.4867, 0.4408]
     # std = [0.2675, 0.2565, 0.2761]
@@ -129,7 +128,7 @@ def main():
 
     print_log("=> creating model '{}'".format(args.arch), log)
     # Init model, criterion, and optimizer
-    net = models.__dict__[args.arch](num_classes)
+    net = models.__dict__[args.arch](num_classes=num_classes)
     #torch.save(net, 'net.pth')
     #init_net = torch.load('net.pth')
     #net.load_my_state_dict(init_net.state_dict())
@@ -187,14 +186,9 @@ def main():
 
     for epoch in range(args.start_epoch, args.epochs):
         #current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
-        current_learning_rate = float(scheduler.get_lr()[-1])
-        #print('lr:',current_learning_rate)
-
-        scheduler.step()
-
-        #adjust_learning_rate(optimizer, epoch)
-
-
+        current_learning_rate = float(scheduler.get_last_lr()[-1])
+        # print('lr:',current_learning_rate)
+
         need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
         need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
 
@@ -204,6 +198,9 @@ def main():
         # train for one epoch
         train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)
 
+        scheduler.step()
+        #adjust_learning_rate(optimizer, epoch)
+
         # evaluate on validation set
         #val_acc, val_los = extract_features(test_loader, net, criterion, log)
         val_acc, val_los = validate(test_loader, net, criterion, log)
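Moving scheduler.step() after the training epoch follows the order PyTorch expects since version 1.1 (optimizer steps first, then the scheduler), and get_last_lr() replaces the deprecated get_lr() for reading the rate currently in effect. A minimal standalone sketch of the same pattern; the model, optimizer, and milestones below are placeholders, not taken from main.py:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200], gamma=0.1)

for epoch in range(300):
    lr = float(scheduler.get_last_lr()[-1])   # rate in effect for this epoch
    # ... run one training epoch here: forward, backward, optimizer.step() per batch ...
    scheduler.step()                          # step the schedule once, after the epoch's optimizer steps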
@@ -250,7 +247,7 @@ def train(train_loader, model, criterion, optimizer, epoch, log):
         data_time.update(time.time() - end)
 
         if args.use_cuda:
-            target = target.cuda(async=True)
+            target = target.cuda()
             input = input.cuda()
         input_var = torch.autograd.Variable(input)
         target_var = torch.autograd.Variable(target)
@@ -261,9 +258,9 @@ def train(train_loader, model, criterion, optimizer, epoch, log):
 
         # measure accuracy and record loss
         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-        losses.update(loss.data[0], input.size(0))
-        top1.update(prec1[0], input.size(0))
-        top5.update(prec5[0], input.size(0))
+        losses.update(loss.item(), input.size(0))
+        top1.update(prec1.item(), input.size(0))
+        top5.update(prec5.item(), input.size(0))
 
         # compute gradient and do SGD step
         optimizer.zero_grad()
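The .item() changes above reflect that indexing a 0-dim tensor (loss.data[0], prec1[0]) stopped working in PyTorch 0.4+; .item() is the supported way to pull out a Python scalar. A short standalone illustration:

import torch

loss = torch.tensor(0.25)    # 0-dim tensor, like the value returned by a loss function
value = loss.item()          # -> 0.25 as a Python float; loss.data[0] would raise an IndexError here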
@@ -293,23 +290,21 @@ def validate(val_loader, model, criterion, log):
 
     # switch to evaluate mode
     model.eval()
-
-    for i, (input, target) in enumerate(val_loader):
-        if args.use_cuda:
-            target = target.cuda(async=True)
-            input = input.cuda()
-        input_var = torch.autograd.Variable(input, volatile=True)
-        target_var = torch.autograd.Variable(target, volatile=True)
-
-        # compute output
-        output = model(input_var)
-        loss = criterion(output, target_var)
-
-        # measure accuracy and record loss
-        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-        losses.update(loss.data[0], input.size(0))
-        top1.update(prec1[0], input.size(0))
-        top5.update(prec5[0], input.size(0))
+    with torch.no_grad():
+        for i, (input, target) in enumerate(val_loader):
+            if args.use_cuda:
+                target = target.cuda()
+                input = input.cuda()
+
+            # compute output
+            output = model(input)
+            loss = criterion(output, target)
+
+            # measure accuracy and record loss
+            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+            losses.update(loss.data.item(), input.size(0))
+            top1.update(prec1.item(), input.size(0))
+            top5.update(prec5.item(), input.size(0))
 
     print_log(' **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg), log)
 
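volatile=True Variables were removed in PyTorch 0.4; wrapping the evaluation loop in torch.no_grad(), as above, is the current way to turn off autograd during validation. A minimal standalone illustration (the tiny model and input are placeholders):

import torch

model = torch.nn.Linear(10, 2).eval()
x = torch.randn(4, 10)
with torch.no_grad():          # no autograd graph is recorded inside this block
    out = model(x)
print(out.requires_grad)       # False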
@@ -322,26 +317,24 @@ def extract_features(val_loader, model, criterion, log):
 
     # switch to evaluate mode
     model.eval()
+    with torch.no_grad():
+        for i, (input, target) in enumerate(val_loader):
+            if args.use_cuda:
+                target = target.cuda()
+                input = input.cuda()
 
-    for i, (input, target) in enumerate(val_loader):
-        if args.use_cuda:
-            target = target.cuda(async=True)
-            input = input.cuda()
-        input_var = torch.autograd.Variable(input, volatile=True)
-        target_var = torch.autograd.Variable(target, volatile=True)
-
-        # compute output
-        output, features = model([input_var])
+            # compute output
+            output, features = model([input])
 
-        pdb.set_trace()
+            pdb.set_trace()
 
-        loss = criterion(output, target_var)
+            loss = criterion(output, target)
 
-        # measure accuracy and record loss
-        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-        losses.update(loss.data[0], input.size(0))
-        top1.update(prec1[0], input.size(0))
-        top5.update(prec5[0], input.size(0))
+            # measure accuracy and record loss
+            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+            losses.update(loss.data.item(), input.size(0))
+            top1.update(prec1.item(), input.size(0))
+            top5.update(prec5.item(), input.size(0))
 
     print_log(' **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg), log)
 
@@ -389,11 +382,11 @@ def accuracy(output, target, topk=(1,)):
 
     _, pred = output.topk(maxk, 1, True, True)
     pred = pred.t()
-    correct = pred.eq(target.view(1, -1).expand_as(pred))
+    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
 
     res = []
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0)
+        correct_k = correct[:k].reshape(-1).float().sum(0)
         res.append(correct_k.mul_(100.0 / batch_size))
     return res
 
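The view-to-reshape change is needed because .view() only works on contiguously laid-out tensors, while .reshape() falls back to a copy when the layout requires it, so it is the safe drop-in on newer PyTorch. A quick standalone illustration of the difference (not the exact tensors used in accuracy()):

import torch

t = torch.tensor([[1, 0, 1]]).expand(2, 3)   # expanded view, non-contiguous
print(t.is_contiguous())                     # False
flat = t.reshape(-1)                         # works (copies if needed); t.view(-1) raises a RuntimeError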