setup custom_load_state_dict for all torch optimizers
NouamaneTazi committed Nov 22, 2024
1 parent bc25a35 commit 3a2a6c7
Showing 2 changed files with 12 additions and 17 deletions.
18 changes: 3 additions & 15 deletions src/nanotron/helpers.py
@@ -7,14 +7,7 @@
 from datetime import datetime
 from functools import partial
 from math import ceil
-from typing import (
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -29,7 +22,7 @@
 from nanotron.distributed import ProcessGroup
 from nanotron.logging import LogItem, log_rank
 from nanotron.models.base import NanotronModel
-from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict
+from nanotron.optim.base import BaseOptimizer, Optimizer
 from nanotron.optim.gradient_accumulator import (
     FP32GradBucketManager,
     FP32GradientAccumulator,
@@ -335,19 +328,14 @@ def basic_optimizer_builder(named_param_groups):
     if optimizer_args.optimizer_factory.name == "adamW":
 
         def optimizer(param_groups):
-            base_optimizer = torch.optim.AdamW(
+            return torch.optim.AdamW(
                 param_groups,
                 lr=optimizer_args.learning_rate_scheduler.learning_rate,
                 weight_decay=optimizer_args.weight_decay,
                 eps=optimizer_args.optimizer_factory.adam_eps,
                 betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2),
                 fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
             )
-            # Replace the load_state_dict method with our custom implementation that enables CPU offload
-            base_optimizer.load_state_dict = lambda state_dict, map_location=None: custom_load_state_dict(
-                base_optimizer, state_dict, map_location=map_location
-            )
-            return base_optimizer
 
     elif optimizer_args.optimizer_factory.name == "sgd":
 
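The helper itself lives in nanotron.optim.base and its body is not part of this diff. Purely as a hedged sketch of what such a helper might do, assuming its job is to honor a map_location argument (e.g. "cpu") when restoring optimizer state instead of letting PyTorch cast every tensor back onto the parameters' devices:

import torch
from typing import Any, Dict, Optional, Union


def custom_load_state_dict(
    optimizer: torch.optim.Optimizer,
    state_dict: Dict[str, Any],
    map_location: Optional[Union[str, torch.device]] = None,
) -> None:
    # Illustrative sketch only -- the real implementation is in nanotron.optim.base.
    # Restore param_groups and per-parameter state via the stock torch method first.
    torch.optim.Optimizer.load_state_dict(optimizer, state_dict)
    if map_location is None:
        return
    # Then move every state tensor to the requested device (CPU offload).
    for param, param_state in optimizer.state.items():
        optimizer.state[param] = {
            k: v.to(map_location) if isinstance(v, torch.Tensor) else v
            for k, v in param_state.items()
        }

With this commit the wrapping no longer happens per optimizer factory in helpers.py; it is applied once in InheritFromOtherOptimizer.__init__ (second file below), so every torch optimizer built here (AdamW, SGD, ...) picks it up automatically.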
11 changes: 9 additions & 2 deletions src/nanotron/optim/inherit_from_other_optimizer.py
@@ -3,14 +3,21 @@
 
 import torch
 
-from nanotron.optim.base import BaseOptimizer, Optimizer
+from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict
 
 
 class InheritFromOtherOptimizer(BaseOptimizer):
     def __init__(self, optimizer: Optimizer, id_to_name: Dict[int, str]):
-        self.optimizer: Optimizer = optimizer
         self.id_to_name = id_to_name
 
+        # If the wrapped optimizer is a plain torch optimizer, replace its
+        # load_state_dict with our custom implementation that enables CPU offload.
+        if isinstance(optimizer, torch.optim.Optimizer):
+            optimizer.load_state_dict = lambda state_dict, map_location=None: custom_load_state_dict(
+                optimizer, state_dict, map_location=map_location
+            )
+        self.optimizer: Optimizer = optimizer
+
     def __getstate__(self):
         return self.optimizer.__getstate__()
 
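A hedged usage sketch of the new behavior: any InheritFromOtherOptimizer wrapping a plain torch optimizer now gets the map_location-aware load_state_dict without helpers.py patching it. The model, id_to_name mapping, and checkpoint path below are illustrative assumptions, not part of the commit:

import torch

from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer

model = torch.nn.Linear(8, 8)
inner = torch.optim.AdamW(model.parameters(), lr=1e-3)
wrapped = InheritFromOtherOptimizer(
    optimizer=inner,
    id_to_name={id(p): name for name, p in model.named_parameters()},
)

# Because inner is a torch.optim.Optimizer, __init__ swapped its load_state_dict
# for custom_load_state_dict, so a checkpoint can be restored with CPU offload:
checkpoint = torch.load("optimizer.pt", map_location="cpu")
wrapped.optimizer.load_state_dict(checkpoint, map_location="cpu")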
