
Commit

Merge branch 'develop' into add-comm-dtypes
siddharth9820 authored Feb 28, 2024
2 parents 79a76aa + fb494f1 commit ae8384b
Showing 8 changed files with 80 additions and 33 deletions.
1 change: 1 addition & 0 deletions axonn/__init__.py
@@ -2,3 +2,4 @@
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from . import models # noqa: F401
4 changes: 0 additions & 4 deletions axonn/intra_layer/fully_connected.py
@@ -264,7 +264,6 @@ def forward(
if not self.transpose:
if scatter_input:
x = Drop.apply(x, self.inner_group)
x = Drop.apply(x, self.depth_group, 0)
x = AsyncLinear.apply(
x,
weight,
@@ -278,11 +277,9 @@
)
if gather_output:
x = Gather.apply(x, self.outer_group)
x = Gather.apply(x, self.depth_group, 0)
else:
if scatter_input:
x = Drop.apply(x, self.outer_group)
x = Drop.apply(x, self.depth_group, 0)

x = AsyncLinear.apply(
x,
@@ -297,7 +294,6 @@
)
if gather_output:
x = Gather.apply(x, self.inner_group)
x = Gather.apply(x, self.depth_group, 0)

if self.bias is None:
return x
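All four removed lines involve the depth tensor-parallel group: `forward` no longer drops its input onto, or gathers its output from, `self.depth_group` when `scatter_input`/`gather_output` are set. As the updated test at the bottom of this diff shows, that sharding now happens in the caller. The following is only a sketch of the new calling pattern for the easy-TP path; it assumes AxoNN's process groups have already been initialized via `ax.init(...)`, and the import path for the test's `_drop`/`_gather` helpers is an assumption (their import lines are outside the hunks shown here).

```python
# Sketch only: condensed from the easy-TP path of the updated
# axonn/tests/test_intra_layer_fc.py. Assumes torch.distributed and ax.init(...)
# have already set up the intra-layer process groups.
import torch
from axonn import axonn as ax
from axonn.intra_layer import Linear
from axonn.intra_layer.communication import _drop, _gather  # assumed import path

B, H = 2, 64  # toy batch and hidden sizes
depth_group = ax.comm_handle.depth_intra_layer_parallel_group

X = torch.randn(B, H).cuda()
X_local = _drop(X, 0, depth_group)  # caller shards the batch rows across the depth group
layer = Linear(in_features=H, out_features=H, bias=False).cuda()
Y_local = layer(X_local, scatter_input=True, gather_output=True)
Y = _gather(Y_local, 0, depth_group)  # caller undoes the depth sharding of the output
```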
2 changes: 2 additions & 0 deletions axonn/models/__init__.py
@@ -0,0 +1,2 @@
# For parallelize context manager use
from . import transformers # noqa: F401
36 changes: 36 additions & 0 deletions axonn/models/transformers/__init__.py
@@ -0,0 +1,36 @@
from contextlib import contextmanager
from transformers import AutoConfig
from .modify_opt import monkey_patch_opt_with_axonn, reverse_monkey_patch_opt_with_axonn
from .modify_llama import (
monkey_patch_llama_with_axonn,
reverse_monkey_patch_llama_with_axonn,
)

modify_dict = {
"OPTForCausalLM": (
monkey_patch_opt_with_axonn,
reverse_monkey_patch_opt_with_axonn,
),
"LlamaForCausalLM": (
monkey_patch_llama_with_axonn,
reverse_monkey_patch_llama_with_axonn,
),
}


@contextmanager
def parallelize(model_id):
config = AutoConfig.from_pretrained(model_id)
architecture = config.architectures[0]
# config.architectures is a list, not sure what to do
# if it has multiple elements
assert (
architecture in modify_dict
), f"{architecture} has not been parallelized within AxoNN"

monkey_patch_fn, reverse_monkey_patch_fn = modify_dict[architecture]
original_attention_init, original_mlp_init = monkey_patch_fn()
try:
yield None
finally:
reverse_monkey_patch_fn(original_attention_init, original_mlp_init)
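For context, a minimal usage sketch of the new `parallelize` context manager. The model id below is only an illustrative example (its `config.architectures[0]` is `OPTForCausalLM`, one of the keys in `modify_dict`); it also assumes `transformers` is installed and that AxoNN has been initialized so the patched layers can build their `axonn.intra_layer.Linear` modules.

```python
# Hypothetical usage sketch of axonn.models.transformers.parallelize.
from transformers import AutoModelForCausalLM
from axonn.models.transformers import parallelize

model_id = "facebook/opt-125m"  # example id whose architecture is OPTForCausalLM
with parallelize(model_id):
    # Inside the block, the OPT attention/decoder-layer constructors are the
    # monkey-patched AxoNN versions, so the model is built with intra-layer
    # tensor-parallel Linear layers.
    model = AutoModelForCausalLM.from_pretrained(model_id)
# On exit, the original Hugging Face constructors are restored.
```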
28 changes: 22 additions & 6 deletions axonn/models/transformers/modify_llama.py
@@ -1,10 +1,20 @@
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN
from axonn.intra_layer import Linear
from typing import Optional


def modified_attention_init(self, config):
def modified_attention_init(self, config, layer_idx: Optional[int] = None):
super(LlamaAttention, self).__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once( # noqa: F821
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " # noqa: E501
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " # noqa: E501
"when creating this class."
)

self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -16,9 +26,10 @@ def modified_attention_init(self, config):

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})."
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" # noqa: E501
f" and `num_heads`: {self.num_heads})."
)

self.q_proj = Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
)
@@ -32,9 +43,7 @@ def modified_attention_init(self, config):
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.o_proj = Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
)
self.o_proj = Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self._init_rope()


@@ -50,5 +59,12 @@ def modified_mlp_init(self, config):


def monkey_patch_llama_with_axonn():
original_inits = LlamaAttention.__init__, LlamaMLP.__init__
LlamaAttention.__init__ = modified_attention_init
LlamaMLP.__init__ = modified_mlp_init
return original_inits


def reverse_monkey_patch_llama_with_axonn(original_attention_init, original_mlp_init):
LlamaAttention.__init__ = original_attention_init
LlamaMLP.__init__ = original_mlp_init
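The patch/restore pair is symmetric: the patch function returns the original constructors and the reverse function reinstates them, which is exactly what `parallelize` relies on in its `try/finally`. A small round-trip sketch (the module path is inferred from the relative import in `axonn/models/transformers/__init__.py`):

```python
# Sketch of the patch/restore round trip driven by the parallelize context manager.
from transformers.models.llama.modeling_llama import LlamaAttention
from axonn.models.transformers.modify_llama import (  # path inferred from __init__.py
    monkey_patch_llama_with_axonn,
    reverse_monkey_patch_llama_with_axonn,
)

orig_attention_init, orig_mlp_init = monkey_patch_llama_with_axonn()
try:
    # LlamaAttention/LlamaMLP now construct axonn.intra_layer.Linear projections.
    assert LlamaAttention.__init__ is not orig_attention_init
finally:
    reverse_monkey_patch_llama_with_axonn(orig_attention_init, orig_mlp_init)

assert LlamaAttention.__init__ is orig_attention_init  # originals restored
```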
File renamed without changes.
7 changes: 7 additions & 0 deletions axonn/models/transformers/modify_opt.py
@@ -56,5 +56,12 @@ def modified_decoder_init(self, config):


def monkey_patch_opt_with_axonn():
original_inits = OPTAttention.__init__, OPTDecoderLayer.__init__
OPTAttention.__init__ = modified_attention_init
OPTDecoderLayer.__init__ = modified_decoder_init
return original_inits


def reverse_monkey_patch_opt_with_axonn(original_attention_init, original_mlp_init):
OPTAttention.__init__ = original_attention_init
OPTDecoderLayer.__init__ = original_mlp_init
35 changes: 12 additions & 23 deletions axonn/tests/test_intra_layer_fc.py
@@ -35,17 +35,14 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
outer_group = ax.comm_handle.outer_intra_layer_parallel_group
depth_group = ax.comm_handle.depth_intra_layer_parallel_group

X_local = _drop(X, 0, depth_group) # divide rows of X along the depth tensor group

if not easy_tp:
# manually divide input
X_local = _drop(
X, 1, inner_group
) # divide colunns of X along the inner tensor group
# manually divide input
X_local = _drop(
X_local, 0, depth_group
) # divide colunns of X along the inner tensor group
else:
X_local = X

layer = Linear(in_features=H, out_features=H, bias=bias).cuda()
layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda()
@@ -58,11 +55,9 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
with torch.no_grad():
# parallel FW pass
Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
Y_parallel = _gather(Y_local.clone(), 0, depth_group)
if not easy_tp: # gather output manually
Y_parallel = _gather(Y_local.clone(), 1, outer_group)
Y_parallel = _gather(Y_parallel.clone(), 0, depth_group)
else:
Y_parallel = Y_local
Y_sequential = layer_sequential(X)

assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"
@@ -117,22 +112,19 @@ def test_bw_pass(
# test if load state dict works with a sharded checkpoint
layer.load_state_dict(layer.state_dict())

X_local = (
_drop(X, 0, depth_group).detach().clone()
) # divide colunns of X along the inner tensor group
if not easy_tp:
X_local = (
_drop(X, 1, inner_group).detach().clone()
_drop(X_local, 1, inner_group).detach().clone()
) # divide colunns of X along the inner tensor group
X_local = (
_drop(X_local, 0, depth_group).detach().clone()
) # divide colunns of X along the inner tensor group
else:
X_local = X

X_local.requires_grad = True

Y_local_grad = _drop(Y_grad, 0, depth_group).detach().clone()
if not easy_tp:
Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone()
Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone()
else:
Y_local_grad = Y_grad
Y_local_grad = _drop(Y_local_grad, 1, outer_group).detach().clone()

with optimize_communication(
overlap_all_reduce=comm_opt_level >= 1,
@@ -144,8 +136,7 @@
Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
Y_local.backward(Y_local_grad)

if not easy_tp:
sync_gradients(layer)
sync_gradients(layer)
if comm_opt_level >= 3:
clear_weights_cache()
# sequential backward pass
@@ -159,11 +150,9 @@
layer_sequential.parameters(), max_norm=clip_grad_norm
)

X_grad_parallel = _gather(X_local.grad, 0, depth_group)
if not easy_tp:
X_grad_parallel = _gather(X_local.grad, 0, depth_group)
X_grad_parallel = _gather(X_grad_parallel, 1, inner_group)
else:
X_grad_parallel = X_local.grad

assert torch.allclose(
X_grad_parallel, X.grad
