Add xlm-roberta model support for tei-gaudi use case (#1715)
Signed-off-by: kaixuanliu <[email protected]>
kaixuanliu authored Jan 30, 2025
1 parent 167c37d commit 43cde94
Showing 4 changed files with 110 additions and 0 deletions.
6 changes: 6 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -272,6 +272,7 @@
    gaudi_xglm_attention_forward,
    gaudi_xglm_decoder_layer_forward,
    gaudi_xglm_model_forward,
    gaudi_XLMRoberta_Sdpa_SelfAttention_forward,
)


@@ -740,6 +741,11 @@ def adapt_transformers_to_gaudi():
    transformers.AutoConfig.register("minicpm3", MiniCPM3Config)
    transformers.AutoModelForCausalLM.register(MiniCPM3Config, MiniCPM3ForCausalLM)

    # Optimization for XLMRoberta model on Gaudi
    transformers.models.xlm_roberta.modeling_xlm_roberta.XLMRobertaSdpaSelfAttention.forward = (
        gaudi_XLMRoberta_Sdpa_SelfAttention_forward
    )

    # Optimization for Baichuan2 on Gaudi
    transformers.AutoConfig.register("baichuan", BaichuanConfig)
    transformers.AutoTokenizer.register(BaichuanConfig, slow_tokenizer_class=BaichuanTokenizer)
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -334,3 +334,4 @@
    gaudi_xglm_decoder_layer_forward,
    gaudi_xglm_model_forward,
)
from .xlm_roberta import gaudi_XLMRoberta_Sdpa_SelfAttention_forward
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/xlm_roberta/__init__.py
@@ -0,0 +1 @@
from .modeling_xlm_roberta import gaudi_XLMRoberta_Sdpa_SelfAttention_forward
102 changes: 102 additions & 0 deletions optimum/habana/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -0,0 +1,102 @@
# coding=utf-8
# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch XLM-RoBERTa model."""

from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from habana_frameworks.torch.hpex.kernels import FusedSDPA

from optimum.utils import logging


logger = logging.get_logger(__name__)


def gaudi_XLMRoberta_Sdpa_SelfAttention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    r"""
    Copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L295
    Changes:
    - Use HPU's FusedSDPA (fast softmax mode) to replace `torch.nn.functional.scaled_dot_product_attention`
    """
    if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
        # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
        logger.warning_once(
            "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
            "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
            "the manual attention implementation, but specifying the manual implementation will be required from "
            "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
            '`attn_implementation="eager"` when loading the model.'
        )
        return super().forward(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )

    bsz, tgt_len, _ = hidden_states.size()

    query_layer = self.transpose_for_scores(self.query(hidden_states))

    # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
    # mask needs to be such that the encoder's padding tokens are not attended to.
    is_cross_attention = encoder_hidden_states is not None

    current_states = encoder_hidden_states if is_cross_attention else hidden_states
    attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

    # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
    if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
        key_layer, value_layer = past_key_value
    else:
        key_layer = self.transpose_for_scores(self.key(current_states))
        value_layer = self.transpose_for_scores(self.value(current_states))
        if past_key_value is not None and not is_cross_attention:
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

    if self.is_decoder:
        past_key_value = (key_layer, value_layer)

    is_causal = (
        True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
    )

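    # FusedSDPA.apply arguments, as read from this call site (assumed): query, key, value, attn_mask,
    # dropout_p, is_causal, scale, softmax_mode, recompute. Here scale=None keeps the default
    # 1/sqrt(head_dim) scaling, "fast" enables the fast softmax mode mentioned in the docstring,
    # and recompute is disabled.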
    attn_output = FusedSDPA.apply(
        query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal, None, "fast", False
    )

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

    outputs = (attn_output,)
    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs
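
For context, a minimal usage sketch (not part of this commit) of how the patched SDPA path gets picked up for the tei-gaudi embedding use case. Only adapt_transformers_to_gaudi comes from this repository; the checkpoint name, CLS pooling, and the htcore import are illustrative assumptions:

import torch
import habana_frameworks.torch.core as htcore  # noqa: F401  # makes the "hpu" device available
from transformers import AutoModel, AutoTokenizer

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Patch transformers in place; after this, XLMRobertaSdpaSelfAttention.forward is the
# Gaudi FusedSDPA version added by this commit.
adapt_transformers_to_gaudi()

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # illustrative checkpoint
model = AutoModel.from_pretrained("xlm-roberta-base", attn_implementation="sdpa").to("hpu").eval()

inputs = tokenizer(["Hello from Gaudi"], return_tensors="pt").to("hpu")
with torch.no_grad():
    embeddings = model(**inputs).last_hidden_state[:, 0]  # simple CLS pooling as an example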
