From 21b240842029b4395cb59e99e34356ac400e9dea Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Mon, 13 Jan 2025 12:34:46 +0000
Subject: [PATCH] add tp_recompute_allgather to column linear

---
 src/nanotron/fp8/parameter.py               |  8 ++++----
 src/nanotron/models/llama.py                | 11 +++--------
 src/nanotron/optim/gradient_accumulator.py  |  5 ++++-
 src/nanotron/parallel/tensor_parallel/nn.py |  5 +++--
 tests/test_parameter.py                     |  6 ------
 5 files changed, 14 insertions(+), 21 deletions(-)

diff --git a/src/nanotron/fp8/parameter.py b/src/nanotron/fp8/parameter.py
index 3ed40f98..3e736724 100644
--- a/src/nanotron/fp8/parameter.py
+++ b/src/nanotron/fp8/parameter.py
@@ -9,6 +9,8 @@
 from nanotron.fp8.meta import FP8Meta
 from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor
 
+# from nanotron.config.fp8_config import FP8Args
+
 
 class FP8Parameter(nn.Parameter):
     """
@@ -25,8 +27,6 @@ def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True,
         with torch.no_grad():
             from typing import cast
 
-            from nanotron.config.fp8_config import FP8Args
-
             if constants.CONFIG is None:
                 sync_amax_in_weight = False
             else:
@@ -41,8 +41,8 @@ def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True,
             self._data = FP8Tensor(data, dtype=dtype, interval=interval, sync=sync_amax_in_weight)
 
             # TODO(xrsrke): don't store fp32 raw data in memory after quantization
-            if constants.ITERATION_STEP == 1:
-                self.orig_data = data.data
+            # if constants.ITERATION_STEP == 1:
+            #     self.orig_data = data.data
 
             # TODO(xrsrke): don't fixed these, take it from the FP8 recipe
             fp8e4m3_scale = update_scaling_factor(
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 0d69fe62..598b904a 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -230,8 +230,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=gate_up_contiguous_chunks,
-            name=f"model.decoder.{layer_idx}.mlp.gate_up_proj",
-            # tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         self.down_proj = TensorParallelRowLinear(
             config.intermediate_size,
@@ -240,7 +239,6 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
-            name=f"model.decoder.{layer_idx}.mlp.down_proj",
         )
         self.split_silu_mul = GLUActivation(config.hidden_act)
 
@@ -393,8 +391,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=qkv_contiguous_chunks,
-            name=f"model.decoder.{layer_idx}.attention.qkv_proj",
-            # tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         # TODO(kunhao): We want to have only one version per device and not one version per layer.
         if config.rope_interleaved:
@@ -423,7 +420,6 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication,
-            name=f"model.decoder.{layer_idx}.attention.o_proj",
         )
 
         self.attention = CoreAttention(
@@ -916,8 +912,7 @@ def __init__(
                 # TODO @thomasw21: refactor so that we store that default in a single place.
                 "mode": self.tp_mode,
                 "async_communication": tp_linear_async_communication,
-                "name": "model.lm_head",
-                # "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
+                "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
             },
             module_input_keys={"x"},
             module_output_keys={"logits"},
diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py
index ee532eb1..59694f4d 100644
--- a/src/nanotron/optim/gradient_accumulator.py
+++ b/src/nanotron/optim/gradient_accumulator.py
@@ -10,7 +10,6 @@
 import nanotron.distributed as dist
 from nanotron import logging
 from nanotron.fp8.tensor import FP8Tensor
-from nanotron.fp8.utils import is_overflow_underflow_nan
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage
 
@@ -285,6 +284,8 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
         else:
             grad = half_param.grad
 
+        from nanotron.fp8.utils import is_overflow_underflow_nan
+
         assert is_overflow_underflow_nan(grad) is False, f"Detected overflow/underflow/nan in {name} grad"
 
         fp32_grad = self.get_grad_buffer(name=name)
@@ -309,6 +310,8 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
             else:
                 grad = fp32_grad
             master_param.grad = grad
+            from nanotron.fp8.utils import is_overflow_underflow_nan
+
             assert is_overflow_underflow_nan(master_param.grad) is False
 
     @contextmanager
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index debc8f06..f604f1d7 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -55,7 +55,7 @@ def __init__(
         dtype: torch.dtype = None,
         async_communication: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
-        name: Optional[str] = None,
+        tp_recompute_allgather: bool = True,
         recipe: Optional[FP8LinearRecipe] = None,
     ):
         self.pg = pg
@@ -65,7 +65,6 @@ def __init__(
 
         self.in_features = in_features
         self.out_features = out_features // self.world_size
-        self.name = name
 
         init_args = {
             "in_features": self.in_features,
@@ -81,6 +80,7 @@ def __init__(
 
         self.mode = mode
         self.async_communication = async_communication
+        self.tp_recompute_allgather = tp_recompute_allgather
 
         if contiguous_chunks is not None:
             assert (
@@ -102,6 +102,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
+            tp_recompute_allgather=self.tp_recompute_allgather,
         )
 
     def extra_repr(self) -> str:
diff --git a/tests/test_parameter.py b/tests/test_parameter.py
index cb8ef797..04934707 100644
--- a/tests/test_parameter.py
+++ b/tests/test_parameter.py
@@ -124,10 +124,4 @@ def _test_create_param_that_share_metadata(parallel_context: ParallelContext):
         assert p1_k == p2_k
         assert p1_v == p2_v
 
-    # orig_hash = getattr(orig_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)
-    # new_hash = getattr(new_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)
-
-    # assert new_hash == orig_hash
-    assert hash(new_param) == hash(orig_param)
-
     parallel_context.destroy()