diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index 6ff564a8..73ca3484 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -7,14 +7,7 @@
 from datetime import datetime
 from functools import partial
 from math import ceil
-from typing import (
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -29,7 +22,7 @@
 from nanotron.distributed import ProcessGroup
 from nanotron.logging import LogItem, log_rank
 from nanotron.models.base import NanotronModel
-from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict
+from nanotron.optim.base import BaseOptimizer, Optimizer
 from nanotron.optim.gradient_accumulator import (
     FP32GradBucketManager,
     FP32GradientAccumulator,
@@ -335,7 +328,7 @@ def basic_optimizer_builder(named_param_groups):
         if optimizer_args.optimizer_factory.name == "adamW":
 
             def optimizer(param_groups):
-                base_optimizer = torch.optim.AdamW(
+                return torch.optim.AdamW(
                     param_groups,
                     lr=optimizer_args.learning_rate_scheduler.learning_rate,
                     weight_decay=optimizer_args.weight_decay,
@@ -343,11 +336,6 @@ def optimizer(param_groups):
                     betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2),
                     fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
                 )
-                # Replace the load_state_dict method with our custom implementation that enables CPU offload
-                base_optimizer.load_state_dict = lambda state_dict, map_location=None: custom_load_state_dict(
-                    base_optimizer, state_dict, map_location=map_location
-                )
-                return base_optimizer
 
         elif optimizer_args.optimizer_factory.name == "sgd":
 
diff --git a/src/nanotron/optim/inherit_from_other_optimizer.py b/src/nanotron/optim/inherit_from_other_optimizer.py
index 53b57284..7376a0b3 100644
--- a/src/nanotron/optim/inherit_from_other_optimizer.py
+++ b/src/nanotron/optim/inherit_from_other_optimizer.py
@@ -3,14 +3,21 @@
 
 import torch
 
-from nanotron.optim.base import BaseOptimizer, Optimizer
+from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict
 
 
 class InheritFromOtherOptimizer(BaseOptimizer):
     def __init__(self, optimizer: Optimizer, id_to_name: Dict[int, str]):
-        self.optimizer: Optimizer = optimizer
         self.id_to_name = id_to_name
 
+        # If the wrapped optimizer is a plain torch optimizer, patch its load_state_dict in place
+        if isinstance(optimizer, torch.optim.Optimizer):
+            # Replace the load_state_dict method with our custom implementation that enables CPU offload
+            optimizer.load_state_dict = lambda state_dict, map_location=None: custom_load_state_dict(
+                optimizer, state_dict, map_location=map_location
+            )
+        self.optimizer: Optimizer = optimizer
+
     def __getstate__(self):
         return self.optimizer.__getstate__()
 
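# Sketch (not part of the patch): how the patched load_state_dict is meant to behave
# after this change. The real custom_load_state_dict lives in nanotron.optim.base and
# its body is not shown in this diff, so the helper below is a hypothetical stand-in
# that only illustrates the map_location plumbing (e.g. relocating optimizer state to
# CPU) before delegating to torch's stock loader.
import torch


def custom_load_state_dict(optimizer: torch.optim.Optimizer, state_dict: dict, map_location=None) -> None:
    """Illustrative only: move state tensors to `map_location`, then delegate to torch."""
    if map_location is not None:
        for param_state in state_dict["state"].values():
            for key, value in param_state.items():
                if torch.is_tensor(value):
                    param_state[key] = value.to(map_location)
    # Call the class-level method so we bypass the instance-level monkey-patch below.
    type(optimizer).load_state_dict(optimizer, state_dict)


# Minimal usage mirroring the patch installed in InheritFromOtherOptimizer.__init__.
param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.AdamW([param], lr=1e-3)
param.grad = torch.ones_like(param)
opt.step()  # populate exp_avg / exp_avg_sq so the state dict is non-trivial

snapshot = opt.state_dict()
opt.load_state_dict = lambda state_dict, map_location=None: custom_load_state_dict(
    opt, state_dict, map_location=map_location
)
opt.load_state_dict(snapshot, map_location="cpu")  # the extra kwarg is now accepted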