add tp_recompute_allgather to column linear

xrsrke committed Jan 13, 2025
1 parent c0cb423 commit 21b2408
Showing 5 changed files with 14 additions and 21 deletions.
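
The change threads a single boolean, tp_recompute_allgather, from the model code down into TensorParallelColumnLinear. Conceptually, when the column-parallel linear all-gathers sequence-parallel activations before its matmul, the flag controls whether that gathered tensor is kept for the backward pass or re-gathered on demand. The sketch below is illustrative only: a simplified stand-in for the column-linear autograd function, assuming an already-initialized torch.distributed process group (e.g. under torchrun). It is not this repository's implementation.

import torch
import torch.distributed as dist
import torch.nn.functional as F


class ColumnLinearRecomputeAllGather(torch.autograd.Function):
    """Column-parallel linear that re-does the input all-gather in backward (sketch)."""

    @staticmethod
    def forward(ctx, x, weight, group):
        world_size = dist.get_world_size(group)
        gathered = x.new_empty(world_size * x.shape[0], *x.shape[1:])
        dist.all_gather_into_tensor(gathered, x.contiguous(), group=group)
        ctx.group = group
        # Save only the local shard; the gathered buffer is freed right after the matmul,
        # which is where the activation-memory saving comes from.
        ctx.save_for_backward(x, weight)
        return F.linear(gathered, weight)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        world_size = dist.get_world_size(ctx.group)
        # Recompute the all-gather instead of having cached it during forward.
        gathered = x.new_empty(world_size * x.shape[0], *x.shape[1:])
        dist.all_gather_into_tensor(gathered, x.contiguous(), group=ctx.group)
        grad_weight = grad_output.flatten(0, -2).t() @ gathered.flatten(0, -2)
        grad_gathered = grad_output @ weight
        # Sum partial input grads across ranks and return each rank its own shard.
        grad_x = torch.empty_like(x)
        dist.reduce_scatter_tensor(grad_x, grad_gathered.contiguous(), group=ctx.group)
        return grad_x, grad_weight, None


# Usage (x_shard: local sequence shard, w_shard: local output-feature shard):
# out = ColumnLinearRecomputeAllGather.apply(x_shard, w_shard, tp_group)
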
8 changes: 4 additions & 4 deletions src/nanotron/fp8/parameter.py
@@ -9,6 +9,8 @@
 from nanotron.fp8.meta import FP8Meta
 from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor
 
+# from nanotron.config.fp8_config import FP8Args
+
 
 class FP8Parameter(nn.Parameter):
     """
@@ -25,8 +27,6 @@ def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True,
         with torch.no_grad():
             from typing import cast
 
-            from nanotron.config.fp8_config import FP8Args
-
             if constants.CONFIG is None:
                 sync_amax_in_weight = False
             else:
@@ -41,8 +41,8 @@ def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True,
             self._data = FP8Tensor(data, dtype=dtype, interval=interval, sync=sync_amax_in_weight)
             # TODO(xrsrke): don't store fp32 raw data in memory after quantization
 
-            if constants.ITERATION_STEP == 1:
-                self.orig_data = data.data
+            # if constants.ITERATION_STEP == 1:
+            #     self.orig_data = data.data
 
             # TODO(xrsrke): don't fixed these, take it from the FP8 recipe
             fp8e4m3_scale = update_scaling_factor(
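
The commented-out lines address the TODO directly above them: keeping orig_data = data.data pins the full-precision tensor in memory for the parameter's lifetime, alongside the FP8 copy. A rough size comparison with plain torch (illustrative only; nanotron's FP8Tensor is not used here, and torch.float8_e4m3fn requires a recent PyTorch):

import torch

data = torch.randn(4096, 4096)              # fp32 source weight
fp8 = data.to(torch.float8_e4m3fn)          # 1-byte-per-element stand-in for the FP8 copy

print(data.nelement() * data.element_size() // 2**20)   # 64 (MiB kept alive by orig_data)
print(fp8.nelement() * fp8.element_size() // 2**20)     # 16 (MiB actually needed)
# Not storing the reference (as the change above does) lets the fp32 buffer be freed
# once the caller drops it.
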
11 changes: 3 additions & 8 deletions src/nanotron/models/llama.py
@@ -230,8 +230,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=gate_up_contiguous_chunks,
-            name=f"model.decoder.{layer_idx}.mlp.gate_up_proj",
-            # tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         self.down_proj = TensorParallelRowLinear(
             config.intermediate_size,
@@ -240,7 +239,6 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
-            name=f"model.decoder.{layer_idx}.mlp.down_proj",
         )
         self.split_silu_mul = GLUActivation(config.hidden_act)
 
@@ -393,8 +391,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=qkv_contiguous_chunks,
-            name=f"model.decoder.{layer_idx}.attention.qkv_proj",
-            # tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         # TODO(kunhao): We want to have only one version per device and not one version per layer.
         if config.rope_interleaved:
@@ -423,7 +420,6 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication,
-            name=f"model.decoder.{layer_idx}.attention.o_proj",
         )
 
         self.attention = CoreAttention(
@@ -916,8 +912,7 @@ def __init__(
                 # TODO @thomasw21: refactor so that we store that default in a single place.
                 "mode": self.tp_mode,
                 "async_communication": tp_linear_async_communication,
-                "name": "model.lm_head",
-                # "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
+                "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
             },
             module_input_keys={"x"},
             module_output_keys={"logits"},
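
With the name= kwargs gone (the layer no longer accepts them after the nn.py change below), the previously commented-out flag becomes live for gate_up_proj, qkv_proj, and the lm_head. The arithmetic below is an illustrative estimate of what recomputing the all-gather saves when sequence-parallel activations are gathered; the sizes are generic LLaMA-like numbers, not values from this repository's configs.

# Per decoder layer, the gathered inputs of gate_up_proj and qkv_proj would otherwise be
# cached for backward.
seq_len, micro_batch, hidden = 4096, 1, 4096
bytes_per_elem = 2                                # bf16 activations
gathered = seq_len * micro_batch * hidden * bytes_per_elem
print(f"~{2 * gathered / 2**20:.0f} MiB of activations not cached per layer")   # ~64 MiB
# The price is one extra all-gather per column-parallel linear during backward.
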
5 changes: 4 additions & 1 deletion src/nanotron/optim/gradient_accumulator.py
@@ -10,7 +10,6 @@
 import nanotron.distributed as dist
 from nanotron import logging
 from nanotron.fp8.tensor import FP8Tensor
-from nanotron.fp8.utils import is_overflow_underflow_nan
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage
 
@@ -285,6 +284,8 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
         else:
             grad = half_param.grad
 
+        from nanotron.fp8.utils import is_overflow_underflow_nan
+
         assert is_overflow_underflow_nan(grad) is False, f"Detected overflow/underflow/nan in {name} grad"
 
         fp32_grad = self.get_grad_buffer(name=name)
@@ -309,6 +310,8 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
             else:
                 grad = fp32_grad
             master_param.grad = grad
+            from nanotron.fp8.utils import is_overflow_underflow_nan
+
             assert is_overflow_underflow_nan(master_param.grad) is False
 
     @contextmanager
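
The top-level import of is_overflow_underflow_nan is replaced by local imports inside _accumulate_grad, most likely to avoid a circular import at module load time. The helper itself is not shown in this diff; the stand-in below illustrates the kind of check the asserts rely on and is not nanotron's implementation.

import torch

def is_overflow_underflow_nan(tensor: torch.Tensor) -> bool:
    # Flags NaN and +/-inf; silent underflow to zero cannot be detected after the fact.
    return bool(torch.isnan(tensor).any() or torch.isinf(tensor).any())

assert is_overflow_underflow_nan(torch.tensor([1.0, float("inf")])) is True
assert is_overflow_underflow_nan(torch.zeros(3)) is False
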
5 changes: 3 additions & 2 deletions src/nanotron/parallel/tensor_parallel/nn.py
@@ -55,7 +55,7 @@ def __init__(
         dtype: torch.dtype = None,
         async_communication: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
-        name: Optional[str] = None,
+        tp_recompute_allgather: bool = True,
         recipe: Optional[FP8LinearRecipe] = None,
     ):
         self.pg = pg
@@ -65,7 +65,6 @@
 
         self.in_features = in_features
         self.out_features = out_features // self.world_size
-        self.name = name
 
         init_args = {
             "in_features": self.in_features,
@@ -81,6 +80,7 @@
 
         self.mode = mode
         self.async_communication = async_communication
+        self.tp_recompute_allgather = tp_recompute_allgather
 
         if contiguous_chunks is not None:
             assert (
@@ -102,6 +102,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
+            tp_recompute_allgather=self.tp_recompute_allgather,
         )
 
     def extra_repr(self) -> str:
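
A hedged usage sketch of the new constructor argument, based only on the parameters visible in this diff and the call sites in llama.py above. The import paths and the tp_pg process group are assumptions (repository layout and a running distributed context), so treat this as illustrative rather than canonical.

from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear

proj = TensorParallelColumnLinear(
    in_features=4096,
    out_features=4 * 4096,
    pg=tp_pg,                               # tensor-parallel process group, assumed in scope
    mode=TensorParallelLinearMode.REDUCE_SCATTER,
    bias=False,
    tp_recompute_allgather=True,            # the default per this diff; pass False to cache the gather
)
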
6 changes: 0 additions & 6 deletions tests/test_parameter.py
@@ -124,10 +124,4 @@ def _test_create_param_that_share_metadata(parallel_context: ParallelContext):
         assert p1_k == p2_k
         assert p1_v == p2_v
 
-    # orig_hash = getattr(orig_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)
-    # new_hash = getattr(new_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)
-
-    # assert new_hash == orig_hash
-    assert hash(new_param) == hash(orig_param)
-
     parallel_context.destroy()
