Skip to content

Commit

Permalink
[SW-208588] Add HPU fp8 Dynamic MOE (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
dudilester committed Feb 11, 2025
1 parent f48dda8 commit d6a4cb6
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 118 deletions.
13 changes: 2 additions & 11 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import accelerate
import transformers
import transformers.utils.fx
Expand Down Expand Up @@ -104,6 +103,7 @@
GaudiMixtralDecoderLayer,
GaudiMixtralForCausalLM,
GaudiMixtralModel,
GaudiMixtralSparseMoeBlock,
GaudiMllamaCrossAttentionDecoderLayer,
GaudiMllamaForCausalLM,
GaudiMllamaForConditionalGeneration,
Expand Down Expand Up @@ -212,9 +212,6 @@
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_mistral_rmsnorm_forward,
gaudi_mixtral_block_dynamic_moe_forward,
gaudi_mixtral_block_moe_forward,
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
gaudi_opt_attention_forward,
gaudi_opt_decoder_forward,
Expand Down Expand Up @@ -555,13 +552,7 @@ def adapt_transformers_to_gaudi():
transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention
transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM
transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel
transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.sparse_moe_forward = (
gaudi_mixtral_block_sparse_moe_forward
)
transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.dynamic_moe_forward = (
gaudi_mixtral_block_dynamic_moe_forward
)
transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_moe_forward
transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock = GaudiMixtralSparseMoeBlock
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer
transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward
transformers.models.mixtral.configuration_mixtral.MixtralConfig = MixtralConfig
Expand Down
4 changes: 1 addition & 3 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,8 @@
GaudiMixtralDecoderLayer,
GaudiMixtralForCausalLM,
GaudiMixtralModel,
GaudiMixtralSparseMoeBlock,
MixtralConfig,
gaudi_mixtral_block_dynamic_moe_forward,
gaudi_mixtral_block_moe_forward,
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
)
from .mllama import (
Expand Down
4 changes: 1 addition & 3 deletions optimum/habana/transformers/models/mixtral/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
GaudiMixtralDecoderLayer,
GaudiMixtralForCausalLM,
GaudiMixtralModel,
gaudi_mixtral_block_dynamic_moe_forward,
gaudi_mixtral_block_moe_forward,
gaudi_mixtral_block_sparse_moe_forward,
GaudiMixtralSparseMoeBlock,
gaudi_mixtral_rmsnorm_forward,
)
160 changes: 59 additions & 101 deletions optimum/habana/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralModel,
MixtralSparseMoeBlock,
apply_rotary_pos_emb,
load_balancing_loss_func,
)
Expand Down Expand Up @@ -148,6 +149,64 @@ def gaudi_mixtral_repeat_kv(
return query_states, key_states, value_states, attention_mask


class GaudiMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config):
super().__init__(config)

def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
original_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

if is_deepspeed_available() and (not self.training):
from deepspeed import comm as dist

if dist.is_initialized():
output_tensors = [router_logits.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, router_logits)
router_logits = torch.cat(output_tensors, dim=1)

routing_weights, selected_experts = calculate_routing_tensors(router_logits, self.top_k, hidden_states.dtype)

final_hidden_states = self.call_dynamic_moe_op(
hidden_states=hidden_states,
expert_routing_table=selected_experts,
router_weights=routing_weights,
)
if is_deepspeed_available() and (not self.training):
from deepspeed import comm as dist

if dist.is_initialized():
dist.all_reduce(final_hidden_states)
return final_hidden_states.view(original_shape), router_logits

def call_dynamic_moe_op(
self,
hidden_states,
expert_routing_table,
router_weights,
):
# pre-processing for custom op inputs
w1_list = [expert.w1.weight for expert in self.experts]
w2_list = [expert.w2.weight for expert in self.experts]
w3_list = [expert.w3.weight for expert in self.experts]

return torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w1=w1_list,
w3=w2_list,
w2=w3_list,
permuted_weights=True,
activation="silu",
experts_min=0,
experts_max=7,
)


class GaudiMixtralAttentionLongSequence:
@staticmethod
def forward(q, k, v, mask, causal, q_block_size):
Expand Down Expand Up @@ -357,107 +416,6 @@ def forward(
return attn_output, attn_weights, past_key_value


def gaudi_mixtral_block_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# We need this workaround until moe op in hpu is supporting fp8
if not self.training and not os.environ.get("QUANT_CONFIG"):
return self.dynamic_moe_forward(hidden_states)

return self.sparse_moe_forward(hidden_states)


def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Copied from MixtralSparseMoeBlock.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py
The only differences are:
- optimize expert forward, remove dynamic control and dynamic shape
"""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

if is_deepspeed_available() and (not self.training):
from deepspeed import comm as dist

if dist.is_initialized():
output_tensors = [router_logits.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, router_logits)
router_logits = torch.cat(output_tensors, dim=1)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)

padded_weights = torch.zeros(
(batch_size * sequence_length, self.num_experts), dtype=hidden_states.dtype, device=hidden_states.device
)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, sequence_length, self.num_experts)
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
padded_weight = padded_weights[expert_idx]
current_state_static = hidden_states.reshape(-1, hidden_dim)
current_hidden_states_static = (
expert_layer(current_state_static).reshape(-1, sequence_length, hidden_dim) * padded_weight
)
final_hidden_states += current_hidden_states_static
# support long sequences exceeding 8192
if not self.training and sequence_length > 8192:
htcore.mark_step()

return final_hidden_states, router_logits


def gaudi_mixtral_block_dynamic_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
original_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

if is_deepspeed_available() and (not self.training):
from deepspeed import comm as dist

if dist.is_initialized():
output_tensors = [router_logits.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, router_logits)
router_logits = torch.cat(output_tensors, dim=1)

routing_weights, selected_experts = calculate_routing_tensors(router_logits, self.top_k, hidden_states.dtype)
# pre-processing for custom op inputs
w1_list = [expert.w1.weight for expert in self.experts]
w2_list = [expert.w2.weight for expert in self.experts]
w3_list = [expert.w3.weight for expert in self.experts]

final_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=selected_experts,
router_weights=routing_weights,
w1=w1_list,
w3=w2_list,
w2=w3_list,
permuted_weights=True,
activation="silu",
experts_min=0,
experts_max=7,
)
if is_deepspeed_available() and (not self.training):
from deepspeed import comm as dist

if dist.is_initialized():
dist.all_reduce(final_hidden_states)
return final_hidden_states.view(original_shape), router_logits


def calculate_routing_tensors(
score: torch.Tensor, topk: int, hidden_states_dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down

0 comments on commit d6a4cb6

Please sign in to comment.