diff --git a/src/nanotron/parallel/parameters.py b/src/nanotron/parallel/parameters.py
index 702a1e80..28995e98 100644
--- a/src/nanotron/parallel/parameters.py
+++ b/src/nanotron/parallel/parameters.py
@@ -1,6 +1,6 @@
 import dataclasses
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from functorch.dim import tree_map
@@ -266,10 +266,14 @@ def wrap(e):
         unwrapped_args = tree_map(unwrap, args)
         unwrapped_kwargs = tree_map(unwrap, kwargs)
-        OPS_THAT_RETURN_ORIGINAL_TENSOR = [
+        OPS_THAT_RETURN_ORIGINAL_TENSOR: List[Union[Callable, str]] = [
             # NOTE: transpose operation
             torch.ops.aten.t.default,
             torch.ops.aten.view.default,
+            # NOTE: torch.ops.aten.slice.default doesn't exist,
+            # so we compare against str(op) instead
+            "aten.slice.Tensor",
+            # torch.ops.aten.slice.default,
             # NOTE: param[local_slices]
             torch.ops.aten.detach.default,
             # NOTE: F.embedding()
             torch.ops.aten.embedding.default,
@@ -292,7 +296,9 @@ def wrap(e):
         else:
             outputs = func(*unwrapped_args, **unwrapped_kwargs)
 
-        if func in OPS_THAT_RETURN_ORIGINAL_TENSOR:
+        if any(func == x for x in OPS_THAT_RETURN_ORIGINAL_TENSOR if not isinstance(x, str)) or any(
+            str(func) == x for x in OPS_THAT_RETURN_ORIGINAL_TENSOR if isinstance(x, str)
+        ):
             return outputs
         else:
             return tree_map(wrap, outputs)
diff --git a/tests/test_parameter.py b/tests/test_parameter.py
index 1630a6e2..580f73d1 100644
--- a/tests/test_parameter.py
+++ b/tests/test_parameter.py
@@ -108,3 +108,22 @@ def _test_create_param_that_share_metadata(parallel_context: ParallelContext):
         assert p1_v == p2_v
 
     parallel_context.destroy()
+
+
+@rerun_if_address_is_in_use()
+def test_slice_param():
+    init_distributed(tp=2, dp=1, pp=1)(_test_slice_param)()
+
+
+def _test_slice_param(parallel_context: ParallelContext):
+    param = torch.nn.Parameter(torch.randn(16, 64))
+    split_config = SplitConfig(
+        split_dim=0,
+        contiguous_chunks=(8, 8),
+    )
+    param = create_sharded_parameter_from_config(parameter=param, pg=parallel_context.tp_pg, split_config=split_config)
+
+    sliced = param[2:4, 4:6]
+
+    assert isinstance(sliced, torch.Tensor)
+    assert sliced.shape == (2, 2)
diff --git a/tests/test_zero.py b/tests/test_zero.py
index 830ea739..98e1b0df 100644
--- a/tests/test_zero.py
+++ b/tests/test_zero.py
@@ -276,7 +276,10 @@ def _test_zero_optimizer_with_tp(
             for local_global_slices_pair in sharded_info.local_global_slices_pairs:
                 local_slices = local_global_slices_pair.local_slices
                 global_slices = local_global_slices_pair.global_slices
-                param[local_slices].copy_(ref_param[global_slices])
+
+                with torch.no_grad():
+                    # NOTE: in-place copy into a view of a grad-requiring parameter must run under no_grad
+                    param[local_slices].copy_(ref_param[global_slices])
         else:
             param.copy_(ref_param)
 