From 508798b9f532ab7a550f93ae1aab0e8c5e4097c1 Mon Sep 17 00:00:00 2001
From: jwendlan <82612519+jwendlan@users.noreply.github.com>
Date: Tue, 27 Feb 2024 18:50:27 -0500
Subject: [PATCH 1/3] adding parallelize context for opt (#65)

---
 axonn/models/__init__.py                   |  1 +
 axonn/models/transformers/__init__.py      | 29 +++++++++++++++++++
 .../models}/transformers/modify_llama.py   |  5 ++++
 .../models}/transformers/modify_mistral.py |  0
 .../models}/transformers/modify_opt.py     |  5 ++++
 5 files changed, 40 insertions(+)
 create mode 100644 axonn/models/__init__.py
 create mode 100644 axonn/models/transformers/__init__.py
 rename {models => axonn/models}/transformers/modify_llama.py (93%)
 rename {models => axonn/models}/transformers/modify_mistral.py (100%)
 rename {models => axonn/models}/transformers/modify_opt.py (93%)

diff --git a/axonn/models/__init__.py b/axonn/models/__init__.py
new file mode 100644
index 0000000..ea17b1e
--- /dev/null
+++ b/axonn/models/__init__.py
@@ -0,0 +1 @@
+# For parallelize context manager use
diff --git a/axonn/models/transformers/__init__.py b/axonn/models/transformers/__init__.py
new file mode 100644
index 0000000..27b6013
--- /dev/null
+++ b/axonn/models/transformers/__init__.py
@@ -0,0 +1,29 @@
+from contextlib import contextmanager
+from modify_opt import monkey_patch_opt_with_axonn
+from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer
+from modify_llama import monkey_patch_llama_with_axonn
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
+
+modify_dict = {
+    "facebook/opt-125m": monkey_patch_opt_with_axonn,
+    "facebook/opt-350m": monkey_patch_opt_with_axonn,
+    "facebook/opt-1.3b": monkey_patch_opt_with_axonn,
+    "codellama/CodeLlama-70b-hf": monkey_patch_llama_with_axonn,
+    "codellama/CodeLlama-34b-hf": monkey_patch_llama_with_axonn,
+    "codellama/CodeLlama-13b-hf": monkey_patch_llama_with_axonn,
+    "codellama/CodeLlama-7b-hf": monkey_patch_llama_with_axonn,
+    "deepseek-ai/deepseek-coder-6.7b-base": monkey_patch_llama_with_axonn,
+    "meta-llama/Llama-2-7b-hf": monkey_patch_llama_with_axonn,
+}
+
+
+@contextmanager
+def parallelize(model_id):
+    original_inits = modify_dict[model_id]()  # call to monkey patch
+    try:
+        yield None
+    finally:
+        OPTAttention.__init__ = original_inits["OPTAttention"]
+        OPTDecoderLayer.__init__ = original_inits["OPTDecoderLayer"]
+        LlamaAttention.__init__ = original_inits["LlamaAttention"]
+        LlamaMLP.__init__ = original_inits["LlamaMLP"]
diff --git a/models/transformers/modify_llama.py b/axonn/models/transformers/modify_llama.py
similarity index 93%
rename from models/transformers/modify_llama.py
rename to axonn/models/transformers/modify_llama.py
index af0a580..1403f5a 100644
--- a/models/transformers/modify_llama.py
+++ b/axonn/models/transformers/modify_llama.py
@@ -50,5 +50,10 @@ def modified_mlp_init(self, config):
 
 
 def monkey_patch_llama_with_axonn():
+    original_inits = {
+        "LlamaAttention": LlamaAttention.__init__,
+        "LlamaMLP": LlamaMLP.__init__,
+    }
     LlamaAttention.__init__ = modified_attention_init
     LlamaMLP.__init__ = modified_mlp_init
+    return original_inits
diff --git a/models/transformers/modify_mistral.py b/axonn/models/transformers/modify_mistral.py
similarity index 100%
rename from models/transformers/modify_mistral.py
rename to axonn/models/transformers/modify_mistral.py
diff --git a/models/transformers/modify_opt.py b/axonn/models/transformers/modify_opt.py
similarity index 93%
rename from models/transformers/modify_opt.py
rename to axonn/models/transformers/modify_opt.py
index ee1130e..1fd6969 100644
--- a/models/transformers/modify_opt.py
+++ b/axonn/models/transformers/modify_opt.py
@@ -56,5 +56,10 @@ def modified_decoder_init(self, config):
 
 
 def monkey_patch_opt_with_axonn():
+    original_inits = {
+        "OPTAttention": OPTAttention.__init__,
+        "OPTDecoderLayer": OPTDecoderLayer.__init__,
+    }
    OPTAttention.__init__ = modified_attention_init
    OPTDecoderLayer.__init__ = modified_decoder_init
+    return original_inits
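For context (not part of the patch), a minimal usage sketch of the parallelize context manager introduced above. It assumes torch.distributed and AxoNN's intra-layer parallel groups have already been initialized elsewhere, and it uses the standard Hugging Face AutoModelForCausalLM loader; the model id must be one of the keys in modify_dict.

    # Hypothetical usage sketch; assumes AxoNN and torch.distributed
    # have already been initialized elsewhere.
    from transformers import AutoModelForCausalLM
    from axonn.models.transformers import parallelize

    with parallelize("facebook/opt-125m"):
        # Inside the context, OPTAttention / OPTDecoderLayer (or the Llama
        # equivalents) are monkey-patched so the model is built with AxoNN's
        # tensor-parallel Linear layers.
        model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    # On exit, the original __init__ methods are restored.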
From f975e5865dce7103aeac62fef3072bc134688158 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Tue, 27 Feb 2024 20:14:44 -0500
Subject: [PATCH 2/3] Removing the drop and gathers in depth tensor parallelism for the easy API (#66)

---
 axonn/intra_layer/fully_connected.py |  4 ----
 axonn/tests/test_intra_layer_fc.py   | 35 ++++++++++------------------
 2 files changed, 12 insertions(+), 27 deletions(-)

diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py
index f4561e4..6d2cfae 100644
--- a/axonn/intra_layer/fully_connected.py
+++ b/axonn/intra_layer/fully_connected.py
@@ -264,7 +264,6 @@ def forward(
         if not self.transpose:
             if scatter_input:
                 x = Drop.apply(x, self.inner_group)
-                x = Drop.apply(x, self.depth_group, 0)
             x = AsyncLinear.apply(
                 x,
                 weight,
@@ -278,11 +277,9 @@ def forward(
             )
             if gather_output:
                 x = Gather.apply(x, self.outer_group)
-                x = Gather.apply(x, self.depth_group, 0)
         else:
             if scatter_input:
                 x = Drop.apply(x, self.outer_group)
-                x = Drop.apply(x, self.depth_group, 0)

             x = AsyncLinear.apply(
                 x,
@@ -297,7 +294,6 @@ def forward(
             )
             if gather_output:
                 x = Gather.apply(x, self.inner_group)
-                x = Gather.apply(x, self.depth_group, 0)

         if self.bias is None:
             return x
diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py
index d8d5821..af09ab7 100644
--- a/axonn/tests/test_intra_layer_fc.py
+++ b/axonn/tests/test_intra_layer_fc.py
@@ -35,17 +35,14 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
     outer_group = ax.comm_handle.outer_intra_layer_parallel_group
     depth_group = ax.comm_handle.depth_intra_layer_parallel_group

+    X_local = _drop(X, 0, depth_group)  # divide rows of X along the depth tensor group
+
     if not easy_tp:
         # manually divide input
         X_local = _drop(
             X, 1, inner_group
         )  # divide colunns of X along the inner tensor group
         # manually divide input
-        X_local = _drop(
-            X_local, 0, depth_group
-        )  # divide colunns of X along the inner tensor group
-    else:
-        X_local = X

     layer = Linear(in_features=H, out_features=H, bias=bias).cuda()
     layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda()
@@ -58,11 +55,9 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
     with torch.no_grad():
         # parallel FW pass
         Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
+        Y_parallel = _gather(Y_local.clone(), 0, depth_group)
         if not easy_tp:  # gather output manually
             Y_parallel = _gather(Y_local.clone(), 1, outer_group)
-            Y_parallel = _gather(Y_parallel.clone(), 0, depth_group)
-        else:
-            Y_parallel = Y_local
         Y_sequential = layer_sequential(X)

     assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"
@@ -117,22 +112,19 @@ def test_bw_pass(
     # test if load state dict works with a sharded checkpoint
     layer.load_state_dict(layer.state_dict())

+    X_local = (
+        _drop(X, 0, depth_group).detach().clone()
+    )  # divide colunns of X along the inner tensor group
     if not easy_tp:
         X_local = (
-            _drop(X, 1, inner_group).detach().clone()
+            _drop(X_local, 1, inner_group).detach().clone()
         )  # divide colunns of X along the inner tensor group
-        X_local = (
-            _drop(X_local, 0, depth_group).detach().clone()
-        )  # divide colunns of X along the inner tensor group
-    else:
-        X_local = X
     X_local.requires_grad = True
+
+    Y_local_grad = _drop(Y_grad, 0, depth_group).detach().clone()
     if not easy_tp:
-        Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone()
-        Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone()
-    else:
-        Y_local_grad = Y_grad
+        Y_local_grad = _drop(Y_local_grad, 1, outer_group).detach().clone()

     with optimize_communication(
         overlap_all_reduce=comm_opt_level >= 1,
         overlap_reduce_scatter=comm_opt_level >= 2,
         cache_weights=comm_opt_level >= 3,
         overlap_all_gather=comm_opt_level == 4,
         model_object=layer,
     ):
         Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
         Y_local.backward(Y_local_grad)

-    if not easy_tp:
-        sync_gradients(layer)
+    sync_gradients(layer)
     if comm_opt_level >= 3:
         clear_weights_cache()
     # sequential backward pass
@@ -159,11 +150,9 @@ def test_bw_pass(
             layer_sequential.parameters(), max_norm=clip_grad_norm
         )

+    X_grad_parallel = _gather(X_local.grad, 0, depth_group)
     if not easy_tp:
-        X_grad_parallel = _gather(X_local.grad, 0, depth_group)
         X_grad_parallel = _gather(X_grad_parallel, 1, inner_group)
-    else:
-        X_grad_parallel = X_local.grad

     assert torch.allclose(
         X_grad_parallel, X.grad
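To illustrate the effect of this patch (not part of it): with the easy API, the layer no longer drops or gathers activations along the depth tensor-parallel group, so the caller shards the batch dimension across the depth group itself, as the updated tests do with _drop and _gather. Below is a rough sketch of that calling convention under stated assumptions: the `from axonn import axonn as ax` import and a prior AxoNN/torch.distributed initialization are assumed (only the comm_handle attribute names appear in the test file), and the chunk-based indexing is a local stand-in for the tests' _drop(X, 0, depth_group).

    # Illustrative sketch only; assumes AxoNN and torch.distributed are already
    # initialized, and that the group handles below match the test file.
    import torch
    import torch.distributed as dist
    from axonn import axonn as ax
    from axonn.intra_layer import Linear

    depth_group = ax.comm_handle.depth_intra_layer_parallel_group
    d_rank = dist.get_rank(depth_group)
    d_size = dist.get_world_size(depth_group)

    X = torch.randn(64, 1024).cuda()
    X_local = X.chunk(d_size, dim=0)[d_rank]  # shard the batch across the depth group

    layer = Linear(in_features=1024, out_features=1024).cuda()
    Y_local = layer(X_local, scatter_input=True, gather_output=True)  # easy API
    # Y_local stays sharded along the depth group; gathering it over depth_group
    # along dimension 0 recovers the full output, as test_fw_pass now does.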
From fb494f167d69be8ac61672156cebe7db48dd7d07 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Tue, 27 Feb 2024 23:11:19 -0500
Subject: [PATCH 3/3] change parallelize context to use AutoConfig (#67)

---
 axonn/__init__.py                         |  1 +
 axonn/models/__init__.py                  |  1 +
 axonn/models/transformers/__init__.py     | 43 +++++++++++++----------
 axonn/models/transformers/modify_llama.py | 31 ++++++++++------
 axonn/models/transformers/modify_opt.py   | 10 +++---
 5 files changed, 54 insertions(+), 32 deletions(-)

diff --git a/axonn/__init__.py b/axonn/__init__.py
index 8abdbbe..ab84c20 100644
--- a/axonn/__init__.py
+++ b/axonn/__init__.py
@@ -2,3 +2,4 @@
 # See the top-level LICENSE file for details.
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from . import models  # noqa: F401
diff --git a/axonn/models/__init__.py b/axonn/models/__init__.py
index ea17b1e..da56eb3 100644
--- a/axonn/models/__init__.py
+++ b/axonn/models/__init__.py
@@ -1 +1,2 @@
 # For parallelize context manager use
+from . import transformers  # noqa: F401
diff --git a/axonn/models/transformers/__init__.py b/axonn/models/transformers/__init__.py
index 27b6013..db33249 100644
--- a/axonn/models/transformers/__init__.py
+++ b/axonn/models/transformers/__init__.py
@@ -1,29 +1,36 @@
 from contextlib import contextmanager
-from modify_opt import monkey_patch_opt_with_axonn
-from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer
-from modify_llama import monkey_patch_llama_with_axonn
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
+from transformers import AutoConfig
+from .modify_opt import monkey_patch_opt_with_axonn, reverse_monkey_patch_opt_with_axonn
+from .modify_llama import (
+    monkey_patch_llama_with_axonn,
+    reverse_monkey_patch_llama_with_axonn,
+)

 modify_dict = {
-    "facebook/opt-125m": monkey_patch_opt_with_axonn,
-    "facebook/opt-350m": monkey_patch_opt_with_axonn,
-    "facebook/opt-1.3b": monkey_patch_opt_with_axonn,
-    "codellama/CodeLlama-70b-hf": monkey_patch_llama_with_axonn,
-    "codellama/CodeLlama-34b-hf": monkey_patch_llama_with_axonn,
-    "codellama/CodeLlama-13b-hf": monkey_patch_llama_with_axonn,
-    "codellama/CodeLlama-7b-hf": monkey_patch_llama_with_axonn,
-    "deepseek-ai/deepseek-coder-6.7b-base": monkey_patch_llama_with_axonn,
-    "meta-llama/Llama-2-7b-hf": monkey_patch_llama_with_axonn,
+    "OPTForCausalLM": (
+        monkey_patch_opt_with_axonn,
+        reverse_monkey_patch_opt_with_axonn,
+    ),
+    "LlamaForCausalLM": (
+        monkey_patch_llama_with_axonn,
+        reverse_monkey_patch_llama_with_axonn,
+    ),
 }


 @contextmanager
 def parallelize(model_id):
-    original_inits = modify_dict[model_id]()  # call to monkey patch
+    config = AutoConfig.from_pretrained(model_id)
+    architecture = config.architectures[0]
+    # config.architectures is a list, not sure what to do
+    # if it has multiple elements
+    assert (
+        architecture in modify_dict
+    ), f"{architecture} has not been parallelized within AxoNN"
+
+    monkey_patch_fn, reverse_monkey_patch_fn = modify_dict[architecture]
+    original_attention_init, original_mlp_init = monkey_patch_fn()
     try:
         yield None
     finally:
-        OPTAttention.__init__ = original_inits["OPTAttention"]
-        OPTDecoderLayer.__init__ = original_inits["OPTDecoderLayer"]
-        LlamaAttention.__init__ = original_inits["LlamaAttention"]
-        LlamaMLP.__init__ = original_inits["LlamaMLP"]
+        reverse_monkey_patch_fn(original_attention_init, original_mlp_init)
diff --git a/axonn/models/transformers/modify_llama.py b/axonn/models/transformers/modify_llama.py
index 1403f5a..4a1d2f5 100644
--- a/axonn/models/transformers/modify_llama.py
+++ b/axonn/models/transformers/modify_llama.py
@@ -1,10 +1,20 @@
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN
 from axonn.intra_layer import Linear
+from typing import Optional


-def modified_attention_init(self, config):
+def modified_attention_init(self, config, layer_idx: Optional[int] = None):
     super(LlamaAttention, self).__init__()
     self.config = config
+    self.layer_idx = layer_idx
+    if layer_idx is None:
+        logger.warning_once(  # noqa: F821
+            f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "  # noqa: E501
+            "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "  # noqa: E501
+            "when creating this class."
+        )
+
+    self.attention_dropout = config.attention_dropout
     self.hidden_size = config.hidden_size
     self.num_heads = config.num_attention_heads
     self.head_dim = self.hidden_size // self.num_heads
@@ -16,9 +26,10 @@ def modified_attention_init(self, config):

     if (self.head_dim * self.num_heads) != self.hidden_size:
         raise ValueError(
-            f"hidden_size must be divisible by num_heads "
-            f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})."
+            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"  # noqa: E501
+            f" and `num_heads`: {self.num_heads})."
         )
+
     self.q_proj = Linear(
         self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
     )
@@ -32,9 +43,7 @@ def modified_attention_init(self, config):
         self.num_key_value_heads * self.head_dim,
         bias=config.attention_bias,
     )
-    self.o_proj = Linear(
-        self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
-    )
+    self.o_proj = Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)

     self._init_rope()

@@ -50,10 +59,12 @@ def modified_mlp_init(self, config):


 def monkey_patch_llama_with_axonn():
-    original_inits = {
-        "LlamaAttention": LlamaAttention.__init__,
-        "LlamaMLP": LlamaMLP.__init__,
-    }
+    original_inits = LlamaAttention.__init__, LlamaMLP.__init__
     LlamaAttention.__init__ = modified_attention_init
     LlamaMLP.__init__ = modified_mlp_init
     return original_inits
+
+
+def reverse_monkey_patch_llama_with_axonn(original_attention_init, original_mlp_init):
+    LlamaAttention.__init__ = original_attention_init
+    LlamaMLP.__init__ = original_mlp_init
diff --git a/axonn/models/transformers/modify_opt.py b/axonn/models/transformers/modify_opt.py
index 1fd6969..7043748 100644
--- a/axonn/models/transformers/modify_opt.py
+++ b/axonn/models/transformers/modify_opt.py
@@ -56,10 +56,12 @@ def modified_decoder_init(self, config):


 def monkey_patch_opt_with_axonn():
-    original_inits = {
-        "OPTAttention": OPTAttention.__init__,
-        "OPTDecoderLayer": OPTDecoderLayer.__init__,
-    }
+    original_inits = OPTAttention.__init__, OPTDecoderLayer.__init__
     OPTAttention.__init__ = modified_attention_init
     OPTDecoderLayer.__init__ = modified_decoder_init
     return original_inits
+
+
+def reverse_monkey_patch_opt_with_axonn(original_attention_init, original_mlp_init):
+    OPTAttention.__init__ = original_attention_init
+    OPTDecoderLayer.__init__ = original_mlp_init
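For context (not part of the patch series), a usage sketch after this change: parallelize now resolves the model's architecture string via AutoConfig, so any checkpoint whose config reports OPTForCausalLM or LlamaForCausalLM can be used, and importing axonn alone is enough to expose axonn.models.transformers.parallelize. Prior AxoNN/torch.distributed initialization and the standard Hugging Face AutoModelForCausalLM loader are assumed.

    # Hypothetical usage sketch; assumes AxoNN and torch.distributed
    # have already been initialized elsewhere.
    from transformers import AutoModelForCausalLM
    from axonn.models.transformers import parallelize

    model_id = "codellama/CodeLlama-7b-hf"  # its AutoConfig reports LlamaForCausalLM
    with parallelize(model_id):
        model = AutoModelForCausalLM.from_pretrained(model_id)
    # Leaving the context calls reverse_monkey_patch_llama_with_axonn, restoring
    # the stock LlamaAttention / LlamaMLP constructors.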