
Commit 32af730

Merge pull request #2511 from huggingface/corrected_weight_decay
Add corrected weight decay to several optimizers
2 parents 9d294cd + 9ef877a commit 32af730

11 files changed: +299 −26 lines

timm/optim/_optim_factory.py

Lines changed: 127 additions & 0 deletions
@@ -532,6 +532,132 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
         registry.register(opt)


+def _register_corrected_decay_optimizers(registry: OptimizerRegistry) -> None:
+    """Register corrected weight decay optimizer variants"""
+    corrected_optimizers = [
+        OptimInfo(
+            name='adamc',
+            opt_class=AdamWLegacy,
+            description='AdamW with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='nadamc',
+            opt_class=NAdamW,
+            description='NAdamW with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='sgdc',
+            opt_class=SGDW,
+            description='SGD with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='adoptc',
+            opt_class=Adopt,
+            description='Adopt with corrected decoupled weight decay (lr²/max_lr scaling)',
+            defaults={'decoupled': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='lambcd',
+            opt_class=Lamb,
+            description='LAMB with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'decoupled_decay': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='kronc',
+            opt_class=Kron,
+            description='PSGD Kron with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_momentum=True,
+            defaults={'decoupled_decay': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='lionc',
+            opt_class=Lion,
+            description='Lion with corrected weight decay (lr²/max_lr scaling)',
+            has_eps=False,
+            has_betas=True,
+            defaults={'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='lapropc',
+            opt_class=LaProp,
+            description='LaProp with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='rmsproptfc',
+            opt_class=RMSpropTF,
+            description='RMSprop TF-style with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_momentum=True,
+            defaults={'alpha': 0.9, 'decoupled_decay': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='adafactorbvc',
+            opt_class=AdafactorBigVision,
+            description='Adafactor Big Vision with corrected weight decay (lr²/max_lr or lr/max_lr scaling)',
+            defaults={'corrected_weight_decay': True}
+        ),
+    ]
+    for opt in corrected_optimizers:
+        registry.register(opt)
+
+    # Cautious + corrected variants
+    cautious_corrected = [
+        OptimInfo(
+            name='cadamc',
+            opt_class=AdamWLegacy,
+            description='Cautious AdamW with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='cadoptc',
+            opt_class=Adopt,
+            description='Cautious Adopt with corrected decoupled weight decay (lr²/max_lr scaling)',
+            defaults={'decoupled': True, 'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='cnadamc',
+            opt_class=NAdamW,
+            description='Cautious NAdamW with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='csgdc',
+            opt_class=SGDW,
+            description='Cautious SGD with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True, 'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='clionc',
+            opt_class=Lion,
+            description='Cautious Lion with corrected weight decay (lr²/max_lr scaling)',
+            has_eps=False,
+            has_betas=True,
+            defaults={'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='cadafactorbvc',
+            opt_class=AdafactorBigVision,
+            description='Cautious Adafactor Big Vision with corrected weight decay',
+            defaults={'caution': True, 'corrected_weight_decay': True}
+        ),
+    ]
+    for opt in cautious_corrected:
+        registry.register(opt)
+
+
 def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
     cautious_optimizers = [
         OptimInfo(

@@ -896,6 +1022,7 @@ def _register_default_optimizers() -> None:
     _register_apex_optimizers(default_registry)
     _register_bnb_optimizers(default_registry)
     _register_cautious_optimizers(default_registry)
+    _register_corrected_decay_optimizers(default_registry)

     # Register aliases
     default_registry.register_alias('nesterov', 'sgd')
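
The new entries are ordinary registry names, so they should be reachable through timm's optimizer factory once a build containing this commit is installed. A minimal sketch, assuming exactly that (the 'adamc' name comes from the registration above; the model choice and hyperparameters are purely illustrative):

import timm
from timm.optim import create_optimizer_v2

# Any model works; resnet18 is just a small example.
model = timm.create_model('resnet18', num_classes=10)

# 'adamc' maps to AdamWLegacy with corrected_weight_decay=True, so the lr
# passed here also serves as the max_lr reference for the lr**2 / max_lr
# weight decay correction.
optimizer = create_optimizer_v2(model, opt='adamc', lr=1e-3, weight_decay=0.05)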

timm/optim/adafactor_bv.py

Lines changed: 20 additions & 2 deletions
@@ -4,6 +4,10 @@

 Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560

+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
+
 Adaptation and PyTorch modifications by Ross Wightman
 """
 from typing import List, Optional, Tuple, Union

@@ -68,6 +72,7 @@ def __init__(
             clipping_threshold: Optional[float] = None,
             unscaled_wd: bool = False,
             caution: bool = False,
+            corrected_weight_decay: bool = False,
             *,
             foreach: Optional[bool] = False,
     ):

@@ -94,6 +99,7 @@ def __init__(
             clipping_threshold=clipping_threshold,
             unscaled_wd=unscaled_wd,
             caution=caution,
+            corrected_weight_decay=corrected_weight_decay,
             foreach=foreach,
         )
         super().__init__(params, defaults)

@@ -102,6 +108,7 @@ def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('caution', False)
+            group.setdefault('corrected_weight_decay', False)
             group.setdefault('foreach', None)
             for p in group['params']:
                 p_state = self.state.get(p, {})

@@ -197,6 +204,7 @@ def step(self, closure=None):
                 clipping_threshold=group['clipping_threshold'],
                 unscaled_wd=group['unscaled_wd'],
                 caution=group['caution'],
+                max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
             )

         return loss

@@ -222,6 +230,7 @@ def _single_tensor_adafactor(
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
         caution: bool,
+        max_lr: Optional[float],
 ):
     for i, param in enumerate(params):
         grad = grads[i]

@@ -286,10 +295,18 @@ def _single_tensor_adafactor(
         if weight_decay != 0:
             if unscaled_wd:
                 # match big vision impl, 'fully decoupled' decay w/o LR scaling
-                param.mul_(1. - weight_decay)
+                if max_lr is None:
+                    param.mul_(1. - weight_decay)
+                else:
+                    # corrected weight decay: scale by lr / max_lr
+                    param.mul_(1. - (lr / max_lr) * weight_decay)
             else:
                 # match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR
-                param.mul_(1. - lr * weight_decay)
+                if max_lr is None:
+                    param.mul_(1. - lr * weight_decay)
+                else:
+                    # corrected weight decay: scale by lr^2 / max_lr
+                    param.mul_(1. - (lr ** 2 / max_lr) * weight_decay)

         # Update parameters
         param.add_(update, alpha=-1.0)

@@ -315,6 +332,7 @@ def _multi_tensor_adafactor(
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
         caution: bool,
+        max_lr: Optional[float],
 ):
     # FIXME TODO
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
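
To make the two decay branches concrete, here is a small standalone sketch of the multiplier applied to the parameter tensor; the helper name and arguments are illustrative, not part of the diff:

def adafactor_decay_factor(lr, weight_decay, unscaled_wd, max_lr=None):
    # Mirrors the branches above: unscaled_wd applies decay without lr scaling,
    # while the normal path scales decay by lr; with corrected weight decay
    # these become lr / max_lr and lr**2 / max_lr respectively.
    if unscaled_wd:
        scale = 1.0 if max_lr is None else lr / max_lr
    else:
        scale = lr if max_lr is None else lr ** 2 / max_lr
    return 1. - scale * weight_decay

# At lr == max_lr the corrected and standard paths agree; as lr decays under a
# schedule, the corrected variant shrinks the effective decay quadratically.
print(adafactor_decay_factor(1e-3, 0.05, unscaled_wd=False, max_lr=1e-3))  # 0.99995
print(adafactor_decay_factor(1e-4, 0.05, unscaled_wd=False, max_lr=1e-3))  # ~0.9999995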

timm/optim/adamw.py

Lines changed: 10 additions & 1 deletion
@@ -1,6 +1,10 @@
 """ AdamW Optimizer
 Impl copied from PyTorch master

+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
+
 NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
 """
 import math

@@ -31,6 +35,7 @@ class AdamWLegacy(Optimizer):
         amsgrad: whether to use the AMSGrad variant of this algorithm
             from the paper `On the Convergence of Adam and Beyond`
         caution: apply caution when using AdamW
+        corrected_weight_decay: apply corrected weight decay (lr**2 / max_lr)
     """

     def __init__(

@@ -42,6 +47,7 @@ def __init__(
             weight_decay: float = 1e-2,
             amsgrad: bool = False,
             caution: bool = False,
+            corrected_weight_decay: bool = False,
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))

@@ -58,6 +64,7 @@ def __init__(
             weight_decay=weight_decay,
             amsgrad=amsgrad,
             caution=caution,
+            corrected_weight_decay=corrected_weight_decay,
         )
         super(AdamWLegacy, self).__init__(params, defaults)

@@ -66,6 +73,7 @@ def __setstate__(self, state):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
             group.setdefault('caution', False)
+            group.setdefault('corrected_weight_decay', False)

     @torch.no_grad()
     def step(self, closure=None):

@@ -86,7 +94,8 @@ def step(self, closure=None):
                     continue

                 # Perform stepweight decay
-                p.data.mul_(1 - group['lr'] * group['weight_decay'])
+                wd_scale = group['lr'] if not group['corrected_weight_decay'] else group['lr'] ** 2 / self.defaults['lr']
+                p.data.mul_(1 - wd_scale * group['weight_decay'])

                 # Perform optimization step
                 grad = p.grad
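
A short usage sketch of the AdamWLegacy change, assuming the import path in this repo (timm.optim.adamw); the tensor and the manual lr change stand in for a real scheduler and are illustrative only:

import torch
from timm.optim.adamw import AdamWLegacy

param = torch.nn.Parameter(torch.randn(4, 4))
opt = AdamWLegacy([param], lr=1e-3, weight_decay=0.05, corrected_weight_decay=True)

# self.defaults['lr'] (1e-3 here) acts as max_lr. If a schedule drops the
# group lr to 1e-4, the decay uses lr**2 / max_lr = 1e-5 instead of 1e-4,
# i.e. weight decay now tracks the square of the current lr.
for group in opt.param_groups:
    group['lr'] = 1e-4
    print(group['lr'] ** 2 / opt.defaults['lr'])  # ~1e-05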

timm/optim/adopt.py

Lines changed: 16 additions & 2 deletions
@@ -10,6 +10,10 @@
     title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate},
     year = {2024}
 }
+
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
 """
 from typing import cast, List, Optional, Tuple, Union

@@ -66,6 +70,7 @@ def __init__(
             clip_exp: Optional[float] = 0.333,
             weight_decay: float = 0.0,
             decoupled: bool = False,
+            corrected_weight_decay: bool = False,
             *,
             caution: bool = False,
             foreach: Optional[bool] = False,

@@ -98,6 +103,7 @@ def __init__(
             weight_decay=weight_decay,
             clip_exp=clip_exp,
             decoupled=decoupled,
+            corrected_weight_decay=corrected_weight_decay,
             caution=caution,
             maximize=maximize,
             foreach=foreach,

@@ -115,6 +121,7 @@ def __setstate__(self, state):
             group.setdefault("differentiable", False)
             group.setdefault("clip_exp", None)
             group.setdefault("caution", False)
+            group.setdefault("corrected_weight_decay", False)
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):

@@ -222,6 +229,7 @@ def step(self, closure=None):
                 lr=group["lr"],
                 weight_decay=group["weight_decay"],
                 clip_exp=group["clip_exp"],
+                max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
                 decoupled=group["decoupled"],
                 eps=group["eps"],
                 caution=group["caution"],

@@ -251,6 +259,7 @@ def _single_tensor_adopt(
         lr: Union[float, Tensor],
         weight_decay: float,
         clip_exp: Optional[float],
+        max_lr: Optional[float],
         decoupled: bool,
         eps: float,
         caution: bool,

@@ -299,7 +308,8 @@ def _single_tensor_adopt(
             continue

         if weight_decay != 0 and decoupled:
-            param.add_(param, alpha=-lr * weight_decay)
+            wd_scale = lr ** 2 / max_lr if max_lr is not None else lr
+            param.add_(param, alpha=-wd_scale * weight_decay)

         denom = torch.clamp(exp_avg_sq.sqrt(), eps)
         normed_grad = grad.div(denom)

@@ -336,6 +346,7 @@ def _multi_tensor_adopt(
         lr: Union[float, Tensor],
         weight_decay: float,
         clip_exp: Optional[float],
+        max_lr: Optional[float],
         decoupled: bool,
         eps: float,
         caution: bool,

@@ -410,7 +421,8 @@ def _multi_tensor_adopt(
             continue

         if weight_decay != 0 and decoupled:
-            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+            wd_scale = lr ** 2 / max_lr if max_lr is not None else lr
+            torch._foreach_add_(device_params, device_params, alpha=-wd_scale * weight_decay)

         exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
         torch._foreach_maximum_(exp_avg_sq_sqrt, eps)

@@ -460,6 +472,7 @@ def adopt(
         lr: Union[float, Tensor],
         weight_decay: float,
         clip_exp: Optional[float],
+        max_lr: Optional[float],
         decoupled: bool,
         eps: float,
         caution: bool,

@@ -498,6 +511,7 @@ def adopt(
         lr=lr,
         weight_decay=weight_decay,
         clip_exp=clip_exp,
+        max_lr=max_lr,
         decoupled=decoupled,
         eps=eps,
         caution=caution,
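
The decoupled decay change in both the single-tensor and foreach paths reduces to the same scalar, shown here in isolation (a sketch; the values are arbitrary):

import torch

lr, max_lr, weight_decay = 3e-4, 1e-3, 0.1
param = torch.ones(3)

# Standard decoupled decay: param <- param - lr * wd * param
standard = param - lr * weight_decay * param

# Corrected decay: wd_scale = lr**2 / max_lr, so the shrink factor follows lr**2.
wd_scale = lr ** 2 / max_lr
corrected = param - wd_scale * weight_decay * param

print(standard[0].item(), corrected[0].item())  # ~0.99997 vs ~0.999991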
