From 3a5d4ffce20f37800addf95976c68cc345a3d586 Mon Sep 17 00:00:00 2001
From: Suhas Kotha
Date: Thu, 23 Jan 2025 15:05:04 -0800
Subject: [PATCH] sgd implementation

---
 src/levanter/optim/config.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py
index 40be5576d..bda83621d 100644
--- a/src/levanter/optim/config.py
+++ b/src/levanter/optim/config.py
@@ -322,3 +322,33 @@ def _optimizer(learning_rate):
             return optimizer
 
         return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps))
+
+
+@OptimizerConfig.register_subclass("sgd")
+@dataclass
+class SGDConfig(OptimizerConfig):
+    momentum: float = 0.9
+    nesterov: bool = False
+    max_grad_norm: Optional[float] = 1.0
+
+    def build(self, num_train_steps):
+        """Creates the SGD optimizer"""
+        def _optimizer(learning_rate):
+            components = []
+
+            if self.max_grad_norm:
+                components.append(optax.clip_by_global_norm(self.max_grad_norm))
+
+            if self.momentum > 0:
+                components.append(optax.trace(decay=self.momentum, nesterov=self.nesterov))
+
+            if self.weight_decay > 0:
+                components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask()))
+
+            # negative learning rate for descent
+            components.append(optax.scale(-learning_rate))
+
+            optimizer = optax.chain(*components)
+            return optimizer
+
+        return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps))
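
Reviewer note: below is a minimal usage sketch, not part of the patch, showing how the new SGDConfig composes with the surrounding Levanter/optax machinery. It assumes learning_rate, weight_decay, lr_scheduler, and build_weight_decay_mask are inherited from OptimizerConfig, as they are for the existing AdamConfig; the concrete values and the toy parameter pytree are illustrative only.

# Minimal usage sketch (assumptions as stated above; not part of the patch):
# build the optax transform from the new SGDConfig and apply one update
# to a toy parameter pytree.
import jax.numpy as jnp
import optax

from levanter.optim.config import SGDConfig

# learning_rate and weight_decay are assumed to be inherited OptimizerConfig fields
config = SGDConfig(learning_rate=1e-2, momentum=0.9, nesterov=True, weight_decay=0.0)
optimizer = config.build(num_train_steps=1_000)  # optax GradientTransformation with the LR schedule injected

params = {"w": jnp.ones((3,))}
opt_state = optimizer.init(params)

grads = {"w": jnp.full((3,), 0.5)}
# passing params is harmless here and required once weight decay is enabled
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

With weight_decay left at zero, the resulting chain is clip_by_global_norm -> trace -> scale(-learning_rate), which is essentially the composition optax.sgd builds internally, so the structure mirrors the existing AdamConfig.build in this file.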