
Commit 56f8b1a

Update
jmwang0117 committed Oct 21, 2024
1 parent 37ab4dc commit 56f8b1a
Showing 3 changed files with 16 additions and 19 deletions.
utils/checkpoint.py (2 changes: 1 addition & 1 deletion)

@@ -82,7 +82,7 @@ def save(path, model, optimizer, scheduler, epoch, config):
     '''
 
     # Remove recursively if epoch_last folder exists and create new one
-    # _remove_recursively(path)
+    #_remove_recursively(path)
     _create_directory(path)
 
     weights_fpath = os.path.join(path, 'weights_epoch_{}.pth'.format(str(epoch).zfill(3)))
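As context for the hunk above: str(epoch).zfill(3) zero-pads the epoch number, so checkpoint files sort lexicographically in epoch order. A quick sketch of the resulting paths (the path value here is a hypothetical example, not taken from the repo):

import os

path = 'experiments/epoch_last'  # hypothetical checkpoint directory
for epoch in (7, 42, 100):
    print(os.path.join(path, 'weights_epoch_{}.pth'.format(str(epoch).zfill(3))))
# experiments/epoch_last/weights_epoch_007.pth
# experiments/epoch_last/weights_epoch_042.pth
# experiments/epoch_last/weights_epoch_100.pth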
utils/model.py (4 changes: 2 additions & 2 deletions)

@@ -1,5 +1,5 @@
-from networks.dsc import DSC
+from networks.occrwkv import OccRWKV
 
 
 def get_model(_cfg, phase='train'):
-    return DSC(_cfg, phase=phase)
+    return OccRWKV(_cfg, phase=phase)
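Since get_model is the single factory that call sites go through, swapping the import and return here changes the network everywhere without touching callers. A minimal usage sketch, assuming a config object cfg produced by the project's own loader (not shown here):

from utils.model import get_model

# cfg is assumed to be the project's parsed config object (hypothetical here)
model = get_model(cfg, phase='train')  # now returns an OccRWKV instance instead of DSC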
utils/optimizer.py (29 changes: 13 additions & 16 deletions)

@@ -2,22 +2,19 @@
 
 
 def build_optimizer(_cfg, model):
-
-    opt = _cfg._dict['OPTIMIZER']['TYPE']
-    lr = _cfg._dict['OPTIMIZER']['BASE_LR']
-    if 'MOMENTUM' in _cfg._dict['OPTIMIZER']: momentum = _cfg._dict['OPTIMIZER']['MOMENTUM']
-    if 'WEIGHT_DECAY' in _cfg._dict['OPTIMIZER']: weight_decay = _cfg._dict['OPTIMIZER']['WEIGHT_DECAY']
-
-    if opt == 'Adam': optimizer = optim.Adam(model.get_parameters(),
-                                             lr=lr,
-                                             betas=(0.9, 0.999))
-
-    elif opt == 'SGD': optimizer = optim.SGD(model.get_parameters(),
-                                             lr=lr,
-                                             momentum=momentum,
-                                             weight_decay=weight_decay)
-
-    return optimizer
+    opt = _cfg._dict['OPTIMIZER']['TYPE']
+    lr = float(_cfg._dict['OPTIMIZER']['BASE_LR'])
+    momentum = _cfg._dict['OPTIMIZER'].get('MOMENTUM', 0.9)  # Default momentum if not specified
+    weight_decay = _cfg._dict['OPTIMIZER'].get('WEIGHT_DECAY', 0)  # Default weight decay if not specified
+
+    if opt == 'Adam':
+        optimizer = optim.Adam(model.get_parameters(), lr=lr, betas=(0.9, 0.999))
+    elif opt == 'AdamW':
+        optimizer = optim.AdamW(model.get_parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1e-4)
+    elif opt == 'SGD':
+        optimizer = optim.SGD(model.get_parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
+
+    return optimizer
 
 
 def build_scheduler(_cfg, optimizer):
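With dict.get, a config that omits MOMENTUM or WEIGHT_DECAY now falls back to 0.9 and 0 instead of leaving the names unbound (the old code would raise NameError on the SGD branch in that case), and the float() cast on BASE_LR makes string values such as '1e-3' from YAML safe to pass through. Note that the new AdamW branch hardcodes weight_decay=1e-4 rather than reading WEIGHT_DECAY from the config, so that key only affects SGD. A minimal sketch of the fallback behavior; DummyCfg and DummyModel are hypothetical stand-ins for the project's config and network classes:

import torch

from utils.optimizer import build_optimizer

class DummyModel(torch.nn.Module):
    """Tiny stand-in exposing the get_parameters() hook that build_optimizer expects."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def get_parameters(self):
        return self.parameters()

class DummyCfg:
    # No MOMENTUM / WEIGHT_DECAY keys: the .get() defaults (0.9 and 0) kick in
    _dict = {'OPTIMIZER': {'TYPE': 'SGD', 'BASE_LR': '0.01'}}  # string lr: the float() cast handles it

optimizer = build_optimizer(DummyCfg(), DummyModel())
print(optimizer.defaults['momentum'], optimizer.defaults['weight_decay'])  # -> 0.9 0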
