diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 8fe0ba7b99..dc52460b04 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import accelerate
 import transformers
 import transformers.utils.fx
@@ -104,6 +103,7 @@
     GaudiMixtralDecoderLayer,
     GaudiMixtralForCausalLM,
     GaudiMixtralModel,
+    GaudiMixtralSparseMoeBlock,
     GaudiMllamaCrossAttentionDecoderLayer,
     GaudiMllamaForCausalLM,
     GaudiMllamaForConditionalGeneration,
@@ -212,9 +212,6 @@
     gaudi_MambaForCausalLM_prepare_inputs_for_generation,
     gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
     gaudi_mistral_rmsnorm_forward,
-    gaudi_mixtral_block_dynamic_moe_forward,
-    gaudi_mixtral_block_moe_forward,
-    gaudi_mixtral_block_sparse_moe_forward,
     gaudi_mixtral_rmsnorm_forward,
     gaudi_opt_attention_forward,
     gaudi_opt_decoder_forward,
@@ -555,13 +552,7 @@ def adapt_transformers_to_gaudi():
     transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention
     transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM
     transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel
-    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.sparse_moe_forward = (
-        gaudi_mixtral_block_sparse_moe_forward
-    )
-    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.dynamic_moe_forward = (
-        gaudi_mixtral_block_dynamic_moe_forward
-    )
-    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_moe_forward
+    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock = GaudiMixtralSparseMoeBlock
     transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer
     transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward
     transformers.models.mixtral.configuration_mixtral.MixtralConfig = MixtralConfig
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 13b84d48b1..6205604567 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -175,10 +175,8 @@
     GaudiMixtralDecoderLayer,
     GaudiMixtralForCausalLM,
     GaudiMixtralModel,
+    GaudiMixtralSparseMoeBlock,
     MixtralConfig,
-    gaudi_mixtral_block_dynamic_moe_forward,
-    gaudi_mixtral_block_moe_forward,
-    gaudi_mixtral_block_sparse_moe_forward,
     gaudi_mixtral_rmsnorm_forward,
 )
 from .mllama import (
diff --git a/optimum/habana/transformers/models/mixtral/__init__.py b/optimum/habana/transformers/models/mixtral/__init__.py
index 65bdca2fbd..7a612b80b3 100644
--- a/optimum/habana/transformers/models/mixtral/__init__.py
+++ b/optimum/habana/transformers/models/mixtral/__init__.py
@@ -4,8 +4,6 @@
     GaudiMixtralDecoderLayer,
     GaudiMixtralForCausalLM,
     GaudiMixtralModel,
-    gaudi_mixtral_block_dynamic_moe_forward,
-    gaudi_mixtral_block_moe_forward,
-    gaudi_mixtral_block_sparse_moe_forward,
+    GaudiMixtralSparseMoeBlock,
     gaudi_mixtral_rmsnorm_forward,
 )
diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
index c11d7a277a..f75896f7f8 100644
--- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
+++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
@@ -42,6 +42,7 @@
     MixtralDecoderLayer,
     MixtralForCausalLM,
     MixtralModel,
+    MixtralSparseMoeBlock,
     apply_rotary_pos_emb,
     load_balancing_loss_func,
 )
@@ -148,6 +149,64 @@ def gaudi_mixtral_repeat_kv(
     return query_states, key_states, value_states, attention_mask
 
 
+class GaudiMixtralSparseMoeBlock(MixtralSparseMoeBlock):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        original_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        if is_deepspeed_available() and (not self.training):
+            from deepspeed import comm as dist
+
+            if dist.is_initialized():
+                output_tensors = [router_logits.clone() for _ in range(dist.get_world_size())]
+                dist.all_gather(output_tensors, router_logits)
+                router_logits = torch.cat(output_tensors, dim=1)
+
+        routing_weights, selected_experts = calculate_routing_tensors(router_logits, self.top_k, hidden_states.dtype)
+
+        final_hidden_states = self.call_dynamic_moe_op(
+            hidden_states=hidden_states,
+            expert_routing_table=selected_experts,
+            router_weights=routing_weights,
+        )
+        if is_deepspeed_available() and (not self.training):
+            from deepspeed import comm as dist
+
+            if dist.is_initialized():
+                dist.all_reduce(final_hidden_states)
+        return final_hidden_states.view(original_shape), router_logits
+
+    def call_dynamic_moe_op(
+        self,
+        hidden_states,
+        expert_routing_table,
+        router_weights,
+    ):
+        # pre-processing for custom op inputs
+        w1_list = [expert.w1.weight for expert in self.experts]
+        w2_list = [expert.w2.weight for expert in self.experts]
+        w3_list = [expert.w3.weight for expert in self.experts]
+
+        return torch.ops.hpu.mixture_of_experts(
+            hidden_states=hidden_states,
+            expert_routing_table=expert_routing_table,
+            router_weights=router_weights,
+            w1=w1_list,
+            w3=w2_list,
+            w2=w3_list,
+            permuted_weights=True,
+            activation="silu",
+            experts_min=0,
+            experts_max=7,
+        )
+
+
 class GaudiMixtralAttentionLongSequence:
     @staticmethod
     def forward(q, k, v, mask, causal, q_block_size):
@@ -357,107 +416,6 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
-def gaudi_mixtral_block_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    # We need this workaround until moe op in hpu is supporting fp8
-    if not self.training and not os.environ.get("QUANT_CONFIG"):
-        return self.dynamic_moe_forward(hidden_states)
-
-    return self.sparse_moe_forward(hidden_states)
-
-
-def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Copied from MixtralSparseMoeBlock.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py
-    The only differences are:
-    - optimize expert forward, remove dynamic control and dynamic shape
-    """
-    batch_size, sequence_length, hidden_dim = hidden_states.shape
-    hidden_states = hidden_states.view(-1, hidden_dim)
-    # router_logits: (batch * sequence_length, n_experts)
-    router_logits = self.gate(hidden_states)
-
-    if is_deepspeed_available() and (not self.training):
-        from deepspeed import comm as dist
-
-        if dist.is_initialized():
-            output_tensors = [router_logits.clone() for _ in range(dist.get_world_size())]
-            dist.all_gather(output_tensors, router_logits)
-            router_logits = torch.cat(output_tensors, dim=1)
-
-    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-    # we cast back to the input dtype
-    routing_weights = routing_weights.to(hidden_states.dtype)
-
-    final_hidden_states = torch.zeros(
-        (batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
-    )
-
-    padded_weights = torch.zeros(
-        (batch_size * sequence_length, self.num_experts), dtype=hidden_states.dtype, device=hidden_states.device
-    )
-    padded_weights.scatter_(-1, selected_experts, routing_weights)
-    padded_weights = padded_weights.reshape(-1, sequence_length, self.num_experts)
-    padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
-
-    # Loop over all available experts in the model and perform the computation on each expert
-    for expert_idx in range(self.num_experts):
-        expert_layer = self.experts[expert_idx]
-        padded_weight = padded_weights[expert_idx]
-        current_state_static = hidden_states.reshape(-1, hidden_dim)
-        current_hidden_states_static = (
-            expert_layer(current_state_static).reshape(-1, sequence_length, hidden_dim) * padded_weight
-        )
-        final_hidden_states += current_hidden_states_static
-        # support long sequences exceeding 8192
-        if not self.training and sequence_length > 8192:
-            htcore.mark_step()
-
-    return final_hidden_states, router_logits
-
-
-def gaudi_mixtral_block_dynamic_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    batch_size, sequence_length, hidden_dim = hidden_states.shape
-    original_shape = hidden_states.shape
-    hidden_states = hidden_states.view(-1, hidden_dim)
-    # router_logits: (batch * sequence_length, n_experts)
-    router_logits = self.gate(hidden_states)
-
-    if is_deepspeed_available() and (not self.training):
-        from deepspeed import comm as dist
-
-        if dist.is_initialized():
-            output_tensors = [router_logits.clone() for _ in range(dist.get_world_size())]
-            dist.all_gather(output_tensors, router_logits)
-            router_logits = torch.cat(output_tensors, dim=1)
-
-    routing_weights, selected_experts = calculate_routing_tensors(router_logits, self.top_k, hidden_states.dtype)
-    # pre-processing for custom op inputs
-    w1_list = [expert.w1.weight for expert in self.experts]
-    w2_list = [expert.w2.weight for expert in self.experts]
-    w3_list = [expert.w3.weight for expert in self.experts]
-
-    final_hidden_states = torch.ops.hpu.mixture_of_experts(
-        hidden_states=hidden_states,
-        expert_routing_table=selected_experts,
-        router_weights=routing_weights,
-        w1=w1_list,
-        w3=w2_list,
-        w2=w3_list,
-        permuted_weights=True,
-        activation="silu",
-        experts_min=0,
-        experts_max=7,
-    )
-    if is_deepspeed_available() and (not self.training):
-        from deepspeed import comm as dist
-
-        if dist.is_initialized():
-            dist.all_reduce(final_hidden_states)
-    return final_hidden_states.view(original_shape), router_logits
-
-
 def calculate_routing_tensors(
     score: torch.Tensor, topk: int, hidden_states_dtype: torch.dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
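
Note (not part of the patch): the new GaudiMixtralSparseMoeBlock.forward delegates routing to the existing calculate_routing_tensors helper, whose body lies outside the hunks shown above. For reference, below is a minimal sketch of the routing step that helper is expected to perform, inferred from the removed gaudi_mixtral_block_sparse_moe_forward path (float32 softmax over the router logits, top-k selection, renormalization, cast back to the activation dtype). It is illustrative only and not necessarily identical to the implementation in modeling_mixtral.py.

from typing import Tuple

import torch
import torch.nn.functional as F


def calculate_routing_tensors(
    score: torch.Tensor, topk: int, hidden_states_dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Illustrative reconstruction inferred from the removed sparse path above;
    # not taken verbatim from the repository.
    routing_weights = F.softmax(score, dim=1, dtype=torch.float)  # softmax in fp32 for stability
    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)  # keep top-k experts per token
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize the kept weights
    routing_weights = routing_weights.to(hidden_states_dtype)  # cast back to the activation dtype
    return routing_weights, selected_experts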