Skip to content

Commit

Permalink
hot fix for MultiStepLR_HotFix
Browse files Browse the repository at this point in the history
  • Loading branch information
ildoonet authored Apr 11, 2020
1 parent 67a2654 commit 2424224
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions FastAutoAugment/lr_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch

from torch.optim.lr_scheduler import MultiStepLR
from theconf import Config as C


Expand All @@ -10,8 +10,14 @@ def adjust_learning_rate_resnet(optimizer):
"""

if C.get()['epoch'] == 90:
return torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80])
return MultiStepLR_HotFix(optimizer, [30, 60, 80])
elif C.get()['epoch'] == 270: # autoaugment
return torch.optim.lr_scheduler.MultiStepLR(optimizer, [90, 180, 240])
return MultiStepLR_HotFix(optimizer, [90, 180, 240])
else:
raise ValueError('invalid epoch=%d for resnet scheduler' % C.get()['epoch'])


class MultiStepLR_HotFix(MultiStepLR):
    """MultiStepLR variant that keeps ``milestones`` as a plain list.

    After the parent constructor runs, ``self.milestones`` is re-assigned to a
    ``list`` copy of the given milestones. Recent PyTorch versions store the
    milestones as a ``collections.Counter``, which this subclass deliberately
    overrides (the "hot fix") so the attribute stays a simple, serializable
    list of epoch indices.
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
        super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
        # Override whatever container the parent chose with a plain list copy.
        self.milestones = list(milestones)

0 comments on commit 2424224

Please sign in to comment.