initial infra for autotuning and disable non-expert communication pathway for depth tp
siddharth9820 committed Oct 24, 2024
1 parent 8b29b58 commit 161ba1d
Showing 3 changed files with 41 additions and 15 deletions.
16 changes: 13 additions & 3 deletions axonn/communication.py
@@ -15,6 +15,7 @@
    MPI4PY = False
import torch
import numpy as np
+from typing import Sequence, Optional


class communication_handle:
@@ -160,11 +161,20 @@ def __init__(
            self.inner_intra_layer_parallel_group,
            self.outer_intra_layer_parallel_group,
            self.depth_intra_layer_parallel_group,
-        ) = self.get_intra_layer_groups(G_intra_r, G_intra_c, G_intra_d)
+        ) = self.get_intra_layer_groups()

-    def get_intra_layer_groups(self, G_intra_r, G_intra_c, G_intra_d):
+    def get_intra_layer_groups(
+        self, tensor_parallel_dims: Optional[Sequence[int]] = None
+    ):
        G_inter, G_data, G_intra = self.G_inter, self.G_data, self.G_intra
-
+        if tensor_parallel_dims is None:
+            G_intra_r, G_intra_c, G_intra_d = (
+                self.G_intra_r,
+                self.G_intra_c,
+                self.G_intra_d,
+            )
+        else:
+            G_intra_r, G_intra_c, G_intra_d = tensor_parallel_dims
        # first check if these communicators have already
        # been created
        group_key = (G_intra_r, G_intra_c, G_intra_d)
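The new optional tensor_parallel_dims argument is the hook for autotuning: a caller can ask the communication handle for process groups matching an arbitrary (row, column, depth) split instead of the globally configured one, and the handle caches groups it has already built. Below is a minimal sketch of how an autotuner might probe a few candidate splits; it assumes AxoNN has already been initialized so that ax.comm_handle exists, and the candidate list and benchmarking step are purely illustrative.

from axonn import axonn as ax

# Candidate (G_intra_r, G_intra_c, G_intra_d) splits for 8 intra-layer GPUs.
candidate_dims = [(8, 1, 1), (4, 2, 1), (2, 2, 2)]
for dims in candidate_dims:
    # Returns (inner, outer, depth) groups, creating and caching them on first use.
    inner, outer, depth = ax.comm_handle.get_intra_layer_groups(dims)
    # ... time a layer's forward/backward with these groups and keep the fastest dims ...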
17 changes: 14 additions & 3 deletions axonn/intra_layer/automatic_parallelism.py
@@ -31,17 +31,28 @@ def is_parallelizable_embedding(num_embeddings, embedding_dim):


class patched_linear:
-    def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None):
+    def __new__(
+        cls,
+        in_features,
+        out_features,
+        *args,
+        bias=True,
+        device=None,
+        dtype=None,
+        **kwargs,
+    ):
        if is_parallelizable_linear(in_features, out_features):
-            parallel_layer = Linear(in_features, out_features, bias=bias)
+            parallel_layer = Linear(
+                in_features, out_features, bias=bias, *args, **kwargs
+            )
            if device is not None:
                parallel_layer = parallel_layer.to(device)
            if dtype is not None:
                parallel_layer = parallel_layer.to(dtype)
            return parallel_layer
        else:
            sequential_layer = reference_to_original_linear_class(
-                in_features, out_features, bias=bias
+                in_features, out_features, bias=bias, *args, **kwargs
            )
            if device is not None:
                sequential_layer = sequential_layer.to(device)
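Forwarding *args and **kwargs means extra constructor arguments now survive the dispatch in patched_linear.__new__ on both branches, reaching either AxoNN's tensor-parallel Linear or the original class held in reference_to_original_linear_class. A rough sketch of the intended effect follows; the auto_parallelize name and import path are assumptions about how this module's patch is installed, and an initialized torch.distributed/AxoNN setup is assumed.

import torch
from axonn.intra_layer import auto_parallelize  # assumed export that installs patched_linear

with auto_parallelize():
    # A parallelizable shape becomes AxoNN's Linear; anything else falls back to
    # the original nn.Linear. Both now receive the same bias/device/dtype arguments.
    proj = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)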
23 changes: 14 additions & 9 deletions axonn/intra_layer/fully_connected.py
@@ -25,6 +25,7 @@
    GatherChannelsScatterBatch,
    gather_batch_sizes,
)
+from typing import Optional, Sequence


# Wrapper for custom_fwd to handle different versions of PyTorch
@@ -197,6 +198,7 @@ def __init__(
        skip_bias_add=False,
        init_method=None,
        expert_mode=False,
+        tensor_parallel_dims: Optional[Sequence[int]] = None,
        **kwargs,
    ):
        super(Linear, self).__init__()
@@ -205,13 +207,16 @@ def __init__(
        # in_features are distributed across self.inner_group (X tensor parallel group)
        # out_features are distributed across self.inner_group (Y tensor parallel group)
        # if transpose is true then X and Y are swapped
-
-        if not transpose:
-            self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group
-            self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group
-        else:
-            self.inner_group = ax.comm_handle.outer_intra_layer_parallel_group
-            self.outer_group = ax.comm_handle.inner_intra_layer_parallel_group
+        if tensor_parallel_dims is not None and torch.distributed.get_rank() == 0:
+            print(
+                "Manually setting TP dims for a layer with shape",
+                f" - {(in_features, out_features)} | tp-dims = {tensor_parallel_dims}",
+            )
+        self.inner_group, self.outer_group, self.depth_group = (
+            ax.comm_handle.get_intra_layer_groups(tensor_parallel_dims)
+        )
+        if transpose:
+            self.inner_group, self.outer_group = self.outer_group, self.inner_group

        # depth_group is the Z tensor parallel group (akin to FSDP)
        self.depth_group = ax.comm_handle.depth_intra_layer_parallel_group
@@ -303,7 +308,7 @@ def forward(
        original_shape_x = x.shape
        x = x.reshape(-1, x.shape[-1])
        weight = self.weight
-        if not self.expert_mode:
+        if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
            # extra communication to transition from pure data parallelism
            # to 4D hybrid parallelism
            inner_group_batch_sizes = gather_batch_sizes(x.shape[0], self.inner_group)
@@ -321,7 +326,7 @@ def forward(
            (self.local_out_features, self.local_in_features),
            cache_weights_in_all_gather,
        )
-        if not self.expert_mode:
+        if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
            # extra communication to transition from 4D hybrid parallelism
            # to pure data parallelism
            x = GatherChannelsScatterBatch.apply(
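Taken together, a layer can now be built with its own tensor-parallel split, and the forward pass skips the batch scatter/gather on the non-expert path whenever the row and column groups are trivial, which is how the commit disables that communication pathway for depth-only tensor parallelism. A hedged sketch, assuming Linear is exported from axonn.intra_layer and AxoNN has been initialized across 8 GPUs:

from axonn.intra_layer import Linear

# Depth-only split: dims follow the (G_intra_r, G_intra_c, G_intra_d) order that
# get_intra_layer_groups unpacks. With row * column == 1, the new guard in
# forward() skips the extra gather/scatter even though expert_mode is False.
fc = Linear(
    in_features=4096,
    out_features=4096,
    bias=True,
    tensor_parallel_dims=(1, 1, 8),
)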
