diff --git a/axonn/__init__.py b/axonn/__init__.py
index 8abdbbe..ab84c20 100644
--- a/axonn/__init__.py
+++ b/axonn/__init__.py
@@ -2,3 +2,4 @@
 # See the top-level LICENSE file for details.
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from . import models  # noqa: F401
diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py
index f4561e4..6d2cfae 100644
--- a/axonn/intra_layer/fully_connected.py
+++ b/axonn/intra_layer/fully_connected.py
@@ -264,7 +264,6 @@ def forward(
         if not self.transpose:
             if scatter_input:
                 x = Drop.apply(x, self.inner_group)
-                x = Drop.apply(x, self.depth_group, 0)
             x = AsyncLinear.apply(
                 x,
                 weight,
@@ -278,11 +277,9 @@ def forward(
             )
             if gather_output:
                 x = Gather.apply(x, self.outer_group)
-                x = Gather.apply(x, self.depth_group, 0)
         else:
             if scatter_input:
                 x = Drop.apply(x, self.outer_group)
-                x = Drop.apply(x, self.depth_group, 0)

             x = AsyncLinear.apply(
                 x,
@@ -297,7 +294,6 @@ def forward(
             )
             if gather_output:
                 x = Gather.apply(x, self.inner_group)
-                x = Gather.apply(x, self.depth_group, 0)

         if self.bias is None:
             return x
diff --git a/axonn/models/__init__.py b/axonn/models/__init__.py
new file mode 100644
index 0000000..da56eb3
--- /dev/null
+++ b/axonn/models/__init__.py
@@ -0,0 +1,2 @@
+# For parallelize context manager use
+from . import transformers  # noqa: F401
diff --git a/axonn/models/transformers/__init__.py b/axonn/models/transformers/__init__.py
new file mode 100644
index 0000000..db33249
--- /dev/null
+++ b/axonn/models/transformers/__init__.py
@@ -0,0 +1,36 @@
+from contextlib import contextmanager
+from transformers import AutoConfig
+from .modify_opt import monkey_patch_opt_with_axonn, reverse_monkey_patch_opt_with_axonn
+from .modify_llama import (
+    monkey_patch_llama_with_axonn,
+    reverse_monkey_patch_llama_with_axonn,
+)
+
+modify_dict = {
+    "OPTForCausalLM": (
+        monkey_patch_opt_with_axonn,
+        reverse_monkey_patch_opt_with_axonn,
+    ),
+    "LlamaForCausalLM": (
+        monkey_patch_llama_with_axonn,
+        reverse_monkey_patch_llama_with_axonn,
+    ),
+}
+
+
+@contextmanager
+def parallelize(model_id):
+    config = AutoConfig.from_pretrained(model_id)
+    architecture = config.architectures[0]
+    # config.architectures is a list, not sure what to do
+    # if it has multiple elements
+    assert (
+        architecture in modify_dict
+    ), f"{architecture} has not been parallelized within AxoNN"
+
+    monkey_patch_fn, reverse_monkey_patch_fn = modify_dict[architecture]
+    original_attention_init, original_mlp_init = monkey_patch_fn()
+    try:
+        yield None
+    finally:
+        reverse_monkey_patch_fn(original_attention_init, original_mlp_init)
diff --git a/models/transformers/modify_llama.py b/axonn/models/transformers/modify_llama.py
similarity index 61%
rename from models/transformers/modify_llama.py
rename to axonn/models/transformers/modify_llama.py
index af0a580..4a1d2f5 100644
--- a/models/transformers/modify_llama.py
+++ b/axonn/models/transformers/modify_llama.py
@@ -1,10 +1,20 @@
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN
 from axonn.intra_layer import Linear
+from typing import Optional


-def modified_attention_init(self, config):
+def modified_attention_init(self, config, layer_idx: Optional[int] = None):
     super(LlamaAttention, self).__init__()
     self.config = config
+    self.layer_idx = layer_idx
+    if layer_idx is None:
+        logger.warning_once(  # noqa: F821
+            f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "  # noqa: E501
+            "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "  # noqa: E501
+            "when creating this class."
+        )
+
+    self.attention_dropout = config.attention_dropout
     self.hidden_size = config.hidden_size
     self.num_heads = config.num_attention_heads
     self.head_dim = self.hidden_size // self.num_heads
@@ -16,9 +26,10 @@ def modified_attention_init(self, config):

     if (self.head_dim * self.num_heads) != self.hidden_size:
         raise ValueError(
-            f"hidden_size must be divisible by num_heads "
-            f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})."
+            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"  # noqa: E501
+            f" and `num_heads`: {self.num_heads})."
         )
+
     self.q_proj = Linear(
         self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
     )
@@ -32,9 +43,7 @@ def modified_attention_init(self, config):
         self.num_key_value_heads * self.head_dim,
         bias=config.attention_bias,
     )
-    self.o_proj = Linear(
-        self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
-    )
+    self.o_proj = Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
     self._init_rope()


@@ -50,5 +59,12 @@ def modified_mlp_init(self, config):


 def monkey_patch_llama_with_axonn():
+    original_inits = LlamaAttention.__init__, LlamaMLP.__init__
     LlamaAttention.__init__ = modified_attention_init
     LlamaMLP.__init__ = modified_mlp_init
+    return original_inits
+
+
+def reverse_monkey_patch_llama_with_axonn(original_attention_init, original_mlp_init):
+    LlamaAttention.__init__ = original_attention_init
+    LlamaMLP.__init__ = original_mlp_init
diff --git a/models/transformers/modify_mistral.py b/axonn/models/transformers/modify_mistral.py
similarity index 100%
rename from models/transformers/modify_mistral.py
rename to axonn/models/transformers/modify_mistral.py
diff --git a/models/transformers/modify_opt.py b/axonn/models/transformers/modify_opt.py
similarity index 88%
rename from models/transformers/modify_opt.py
rename to axonn/models/transformers/modify_opt.py
index ee1130e..7043748 100644
--- a/models/transformers/modify_opt.py
+++ b/axonn/models/transformers/modify_opt.py
@@ -56,5 +56,12 @@ def modified_decoder_init(self, config):


 def monkey_patch_opt_with_axonn():
+    original_inits = OPTAttention.__init__, OPTDecoderLayer.__init__
     OPTAttention.__init__ = modified_attention_init
     OPTDecoderLayer.__init__ = modified_decoder_init
+    return original_inits
+
+
+def reverse_monkey_patch_opt_with_axonn(original_attention_init, original_mlp_init):
+    OPTAttention.__init__ = original_attention_init
+    OPTDecoderLayer.__init__ = original_mlp_init
diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py
index d8d5821..af09ab7 100644
--- a/axonn/tests/test_intra_layer_fc.py
+++ b/axonn/tests/test_intra_layer_fc.py
@@ -35,17 +35,14 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
     outer_group = ax.comm_handle.outer_intra_layer_parallel_group
     depth_group = ax.comm_handle.depth_intra_layer_parallel_group

+    X_local = _drop(X, 0, depth_group)  # divide rows of X along the depth tensor group
+
     if not easy_tp:  # manually divide input
         X_local = _drop(
             X, 1, inner_group
         )  # divide colunns of X along the inner tensor group
         # manually divide input
-        X_local = _drop(
-            X_local, 0, depth_group
-        )  # divide colunns of X along the inner tensor group
-    else:
-        X_local = X

     layer = Linear(in_features=H, out_features=H, bias=bias).cuda()
     layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda()

@@ -58,11 +55,9 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
     with torch.no_grad():
         # parallel FW pass
         Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
+        Y_parallel = _gather(Y_local.clone(), 0, depth_group)
         if not easy_tp:  # gather output manually
             Y_parallel = _gather(Y_local.clone(), 1, outer_group)
-            Y_parallel = _gather(Y_parallel.clone(), 0, depth_group)
-        else:
-            Y_parallel = Y_local
         Y_sequential = layer_sequential(X)

     assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"
@@ -117,22 +112,19 @@ def test_bw_pass(
     # test if load state dict works with a sharded checkpoint
     layer.load_state_dict(layer.state_dict())

+    X_local = (
+        _drop(X, 0, depth_group).detach().clone()
+    )  # divide colunns of X along the inner tensor group
     if not easy_tp:
         X_local = (
-            _drop(X, 1, inner_group).detach().clone()
+            _drop(X_local, 1, inner_group).detach().clone()
         )  # divide colunns of X along the inner tensor group
-        X_local = (
-            _drop(X_local, 0, depth_group).detach().clone()
-        )  # divide colunns of X along the inner tensor group
-    else:
-        X_local = X

     X_local.requires_grad = True
+
+    Y_local_grad = _drop(Y_grad, 0, depth_group).detach().clone()
     if not easy_tp:
-        Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone()
-        Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone()
-    else:
-        Y_local_grad = Y_grad
+        Y_local_grad = _drop(Y_local_grad, 1, outer_group).detach().clone()

     with optimize_communication(
         overlap_all_reduce=comm_opt_level >= 1,
@@ -144,8 +136,7 @@
         Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
         Y_local.backward(Y_local_grad)

-    if not easy_tp:
-        sync_gradients(layer)
+    sync_gradients(layer)
     if comm_opt_level >= 3:
         clear_weights_cache()
     # sequential backward pass
@@ -159,11 +150,9 @@
             layer_sequential.parameters(), max_norm=clip_grad_norm
         )

+    X_grad_parallel = _gather(X_local.grad, 0, depth_group)
     if not easy_tp:
-        X_grad_parallel = _gather(X_local.grad, 0, depth_group)
         X_grad_parallel = _gather(X_grad_parallel, 1, inner_group)
-    else:
-        X_grad_parallel = X_local.grad

     assert torch.allclose(
         X_grad_parallel, X.grad
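
Usage note: the `parallelize` context manager added in `axonn/models/transformers/__init__.py` is intended to wrap model construction, so the patched `__init__` methods are active only inside the `with` block and the originals are restored on exit. The following is a minimal sketch of that call pattern, not code from this patch; it assumes AxoNN has already been initialized, and `facebook/opt-125m` is used purely as an example of an `OPTForCausalLM` checkpoint.

    from transformers import AutoModelForCausalLM
    from axonn.models.transformers import parallelize

    # While the context is active, OPTAttention / OPTDecoderLayer (or
    # LlamaAttention / LlamaMLP) construct their projection layers with
    # axonn.intra_layer.Linear instead of torch.nn.Linear.
    with parallelize("facebook/opt-125m"):
        model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    # On exit, reverse_monkey_patch_*_with_axonn restores the original
    # __init__ methods, so later model construction is unaffected.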