Skip to content

Commit 599b5a3

Browse files
committed
Minor fixes
1 parent 6ecd72b commit 599b5a3

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Use the model defined in `model.py` to run ImageNet example:
1616
python imagenet.py --dataroot "/path/to/imagenet/"
1717
```
1818

19-
To run continue training from checkpoint
19+
To continue training from checkpoint
2020
```bash
2121
python imagenet.py --dataroot "/path/to/imagenet/" --resume "/path/to/checkpoint/folder"
2222
```

imagenet.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,8 @@
5959
parser.add_argument('--seed', type=int, default=None, metavar='S', help='random seed (default: random)')
6060

6161
# Architecture
62-
parser.add_argument('--scaling', type=float, default=1, metavar='SC', help='Scaling of MobileNet (default x1).')
63-
parser.add_argument('--input-size', type=int, default=224, metavar='I',
64-
help='Input size of MobileNet, multiple of 32 (default 224).')
62+
parser.add_argument('--scaling', type=float, default=1, metavar='SC', help='Scaling of ShuffleNet (default x1).')
63+
parser.add_argument('--input-size', type=int, default=224, metavar='I', help='Input size of ShuffleNet.')
6564
parser.add_argument('--c-tag', type=float, default=0.5, help="c' value")
6665
parser.add_argument('--SE', dest='SE', action='store_true', help='Use SE modules')
6766
parser.add_argument('--residual', dest='residual', action='store_true', help='Just residuals')
@@ -181,12 +180,10 @@ def main():
181180

182181
claimed_acc1 = None
183182
claimed_acc5 = None
184-
if args.input_size in claimed_acc_top1:
185-
if args.scaling in claimed_acc_top1[args.input_size]:
186-
claimed_acc1 = claimed_acc_top1[args.input_size][args.scaling]
187-
claimed_acc5 = claimed_acc_top5[args.input_size][args.scaling]
188-
csv_logger.write_text(
189-
'Claimed accuracies are: {:.2f}% top-1, {:.2f}% top-5'.format(claimed_acc1 * 100., claimed_acc5 * 100.))
183+
if args.SE in claimed_acc_top1:
184+
if args.scaling in claimed_acc_top1[args.SE]:
185+
claimed_acc1 = 1 - claimed_acc_top1[args.SE][args.scaling]
186+
csv_logger.write_text('Claimed accuracy is {:.2f}% top-1'.format(claimed_acc1 * 100.))
190187
train_network(args.start_epoch, args.epochs, scheduler, model, train_loader, val_loader, optimizer, criterion,
191188
device, dtype, args.batch_size, args.log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5,
192189
best_test)

0 commit comments

Comments
 (0)