fix tensor slicing in NanotronParameter
xrsrke committed Jan 15, 2025
1 parent 3dde0af commit 31af4f7
Showing 3 changed files with 32 additions and 4 deletions.
12 changes: 9 additions & 3 deletions src/nanotron/parallel/parameters.py
@@ -1,6 +1,6 @@
 import dataclasses
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from functorch.dim import tree_map
@@ -266,10 +266,14 @@ def wrap(e):
         unwrapped_args = tree_map(unwrap, args)
         unwrapped_kwargs = tree_map(unwrap, kwargs)
 
-        OPS_THAT_RETURN_ORIGINAL_TENSOR = [
+        OPS_THAT_RETURN_ORIGINAL_TENSOR: List[Union[Callable, str]] = [
             # NOTE: transpose operation
             torch.ops.aten.t.default,
             torch.ops.aten.view.default,
+            # NOTE: torch.ops.attn.slice.default doesn't exist
+            # so we use str(op) instead
+            "aten.slice.Tensor",
+            # torch.ops.attn.slice.default, # NOTE: param[local_slices]
             torch.ops.aten.detach.default,
             # NOTE: F.embedding()
             torch.ops.aten.embedding.default,
@@ -292,7 +296,9 @@ def wrap(e):
         else:
             outputs = func(*unwrapped_args, **unwrapped_kwargs)
 
-        if func in OPS_THAT_RETURN_ORIGINAL_TENSOR:
+        if any(func == x for x in OPS_THAT_RETURN_ORIGINAL_TENSOR if not isinstance(x, str)) or any(
+            str(func) == x for x in OPS_THAT_RETURN_ORIGINAL_TENSOR if isinstance(x, str)
+        ):
             return outputs
         else:
             return tree_map(wrap, outputs)
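For context, here is a minimal standalone sketch of the new whitelist check. The helper name returns_original_tensor and the module-level list are illustrative only, not part of NanotronParameter: in the commit, the whitelist and the check live inside __torch_dispatch__. The point is that the slice op arrives as an OpOverload whose str() is "aten.slice.Tensor", so string entries are matched via str(func) while op objects are compared directly.

from typing import Callable, List, Union

import torch

# Illustrative, shortened copy of the whitelist (the real one sits inside __torch_dispatch__).
OPS_THAT_RETURN_ORIGINAL_TENSOR: List[Union[Callable, str]] = [
    torch.ops.aten.t.default,
    "aten.slice.Tensor",  # listed as a string, so it is matched by name
]


def returns_original_tensor(func) -> bool:
    # Mirrors the two any(...) clauses in the diff: direct comparison for op objects,
    # str(func) comparison for ops listed by name.
    return any(func == x for x in OPS_THAT_RETURN_ORIGINAL_TENSOR if not isinstance(x, str)) or any(
        str(func) == x for x in OPS_THAT_RETURN_ORIGINAL_TENSOR if isinstance(x, str)
    )


assert returns_original_tensor(torch.ops.aten.t.default)
assert returns_original_tensor(torch.ops.aten.slice.Tensor)  # str() of this op is "aten.slice.Tensor"
assert not returns_original_tensor(torch.ops.aten.add.Tensor)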
19 changes: 19 additions & 0 deletions tests/test_parameter.py
@@ -108,3 +108,22 @@ def _test_create_param_that_share_metadata(parallel_context: ParallelContext):
         assert p1_v == p2_v
 
     parallel_context.destroy()
+
+
+@rerun_if_address_is_in_use()
+def test_slice_param():
+    init_distributed(tp=2, dp=1, pp=1)(_test_slice_param)()
+
+
+def _test_slice_param(parallel_context: ParallelContext):
+    param = torch.nn.Parameter(torch.randn(16, 64))
+    split_config = SplitConfig(
+        split_dim=0,
+        contiguous_chunks=(8, 8),
+    )
+    param = create_sharded_parameter_from_config(parameter=param, pg=parallel_context.tp_pg, split_config=split_config)
+
+    sliced = param[2:4, 4:6]
+
+    assert isinstance(sliced, torch.Tensor)
+    assert sliced.shape == (2, 2)
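This test exercises the new "aten.slice.Tensor" whitelist entry: indexing a sharded NanotronParameter with param[2:4, 4:6] goes through __torch_dispatch__ and should come back as a plain 2x2 tensor. Assuming the repository's usual pytest-based workflow, an invocation such as pytest tests/test_parameter.py -k test_slice_param (command is an assumption) would run it, with the two tensor-parallel ranks spawned by init_distributed.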
5 changes: 4 additions & 1 deletion tests/test_zero.py
@@ -276,7 +276,10 @@ def _test_zero_optimizer_with_tp(
             for local_global_slices_pair in sharded_info.local_global_slices_pairs:
                 local_slices = local_global_slices_pair.local_slices
                 global_slices = local_global_slices_pair.global_slices
-                param[local_slices].copy_(ref_param[global_slices])
+
+                with torch.no_grad():
+                    # param.data[local_slices].copy_(ref_param.data[global_slices])
+                    param[local_slices].copy_(ref_param[global_slices])
         else:
             param.copy_(ref_param)

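A standalone sketch (not the nanotron test itself) of why the copy is now wrapped in torch.no_grad(): autograd forbids in-place writes to a view of a leaf tensor that requires grad, while the same copy is allowed once gradient tracking is disabled.

import torch

param = torch.nn.Parameter(torch.zeros(4, 4))
ref = torch.randn(4, 4)

try:
    # Slicing a leaf Parameter creates a view; in-place copy_ on it is rejected by autograd.
    param[0:2, 0:2].copy_(ref[0:2, 0:2])
except RuntimeError as e:
    print("without no_grad():", e)

with torch.no_grad():
    # Same in-place copy, but without gradient tracking: this succeeds.
    param[0:2, 0:2].copy_(ref[0:2, 0:2])

assert torch.equal(param.data[0:2, 0:2], ref[0:2, 0:2])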
