diff --git a/optimum/habana/transformers/modeling_rope_utils.py b/optimum/habana/transformers/modeling_rope_utils.py
index 639219c9ab..0a05e51a2f 100644
--- a/optimum/habana/transformers/modeling_rope_utils.py
+++ b/optimum/habana/transformers/modeling_rope_utils.py
@@ -88,6 +88,9 @@ def _dynamic_frequency_update(self, seq_len, device):
             self.max_seq_len_cached = seq_len
 
         if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
             self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
             self.max_seq_len_cached = self.original_max_seq_len
 
diff --git a/optimum/habana/transformers/models/bloom/modeling_bloom.py b/optimum/habana/transformers/models/bloom/modeling_bloom.py
index 3edab86a60..f36c9cd578 100644
--- a/optimum/habana/transformers/models/bloom/modeling_bloom.py
+++ b/optimum/habana/transformers/models/bloom/modeling_bloom.py
@@ -21,7 +21,6 @@
 from typing import Optional, Tuple, Union
 
 import torch
-from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
@@ -544,6 +543,8 @@ def forward(
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
+        # Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
+        num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
         if deprecated_arguments.pop("position_ids", False) is not False:
             # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
             warnings.warn(
@@ -577,14 +578,12 @@
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            batch_size, seq_length, vocab_size = shift_logits.shape
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(
-                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                num_items_in_batch=num_items_in_batch,
             )
 
         if not return_dict:
diff --git a/optimum/habana/transformers/models/codegen/modeling_codegen.py b/optimum/habana/transformers/models/codegen/modeling_codegen.py
index cfe450ab6c..963cead407 100644
--- a/optimum/habana/transformers/models/codegen/modeling_codegen.py
+++ b/optimum/habana/transformers/models/codegen/modeling_codegen.py
@@ -2,7 +2,6 @@
 
 import torch
 import torch.utils.checkpoint
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.codegen.modeling_codegen import (
@@ -164,6 +163,7 @@ def gaudi_codegen_model_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     token_idx: Optional[torch.Tensor] = None,
+    **kwargs,  # NOOP kwargs, for now
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     """
     Copied from CodeGenBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
@@ -397,6 +397,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         token_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -432,12 +433,13 @@
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
             loss = loss.to(hidden_states.dtype)
 
diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py
index ddc52a4a74..508fab27af 100644
--- a/optimum/habana/transformers/models/falcon/modeling_falcon.py
+++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -27,7 +27,6 @@
 
 import habana_frameworks.torch.core as htcore
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers.cache_utils import Cache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
@@ -1040,6 +1039,7 @@ def forward(
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1094,14 +1094,11 @@
 
         loss = None
         if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            batch_size, seq_length, vocab_size = shift_logits.shape
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(
-                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
             )
 
         if not return_dict:
diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py
index a4de41d29a..d2d4209d0e 100755
--- a/optimum/habana/transformers/models/gemma/modeling_gemma.py
+++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py
@@ -603,6 +603,7 @@ def forward(
         flash_attention_causal_mask: Optional[bool] = False,
         cache_idx: int = None,
         lazy_mode: Optional[bool] = True,
+        **kwargs,  # NOOP kwarg for now
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """
         Copied from GemmaModel.forward: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
diff --git a/optimum/habana/transformers/models/gemma2/modeling_gemma2.py b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py
index 505c8c3ac3..7178d8f970 100755
--- a/optimum/habana/transformers/models/gemma2/modeling_gemma2.py
+++ b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py
@@ -143,6 +143,9 @@ def _dynamic_frequency_update(self, seq_len, device):
             self.max_seq_len_cached = seq_len
 
         if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
             self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
             self.max_seq_len_cached = self.original_max_seq_len
 
diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
index 546ee7ef47..e42a8308fa 100644
--- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
+++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -516,6 +516,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -546,14 +547,13 @@
 
         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index f01255624f..608c272135 100644
--- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -22,7 +22,6 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
     GPTBigCodeAttention,
@@ -806,6 +805,7 @@ def forward(
         flash_attention_fast_softmax: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         cache_idx: Optional[int] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -842,12 +842,12 @@
 
         loss = None
         if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
diff --git a/optimum/habana/transformers/models/gpt_neo/modeling_gpt_neo.py b/optimum/habana/transformers/models/gpt_neo/modeling_gpt_neo.py
index b5ef987752..1cb65bffd0 100644
--- a/optimum/habana/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/optimum/habana/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -1,7 +1,6 @@
 from typing import Optional, Tuple, Union
 
 import torch
-from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -305,7 +304,9 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         token_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -341,12 +342,13 @@
             # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
             lm_logits = lm_logits.to(torch.float32)
 
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
             lm_logits = lm_logits.to(hidden_states.dtype)
             loss = loss.to(hidden_states.dtype)
diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
index 4f4a152c67..dd41d7b557 100644
--- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -408,6 +408,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         token_idx: Optional[torch.Tensor] = None,
+        **kwargs,  # Unused for now, mostly for the loss correction
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py
index d4da76d6f2..a719dc645a 100644
--- a/optimum/habana/transformers/models/gptj/modeling_gptj.py
+++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -3,7 +3,6 @@
 import habana_frameworks.torch.core as htcore
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.gptj.configuration_gptj import GPTJConfig
@@ -662,6 +661,7 @@ def forward(
         token_idx: Optional[torch.Tensor] = None,
         reuse_cache: Optional[bool] = False,
         cache_idx: Optional[int] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -699,12 +699,13 @@
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
             loss = loss.to(hidden_states.dtype)
 
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index ce795c0cd8..e10d9e683e 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -149,6 +149,9 @@ def _dynamic_frequency_update(self, seq_len, device):
             self.max_seq_len_cached = seq_len
 
         if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
             self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
             self.max_seq_len_cached = self.original_max_seq_len
 
diff --git a/optimum/habana/transformers/models/mpt/modeling_mpt.py b/optimum/habana/transformers/models/mpt/modeling_mpt.py
index 309e0d7acc..7219ac0f29 100755
--- a/optimum/habana/transformers/models/mpt/modeling_mpt.py
+++ b/optimum/habana/transformers/models/mpt/modeling_mpt.py
@@ -19,7 +19,6 @@
 
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.models.mpt.modeling_mpt import (
     MptAttention,
@@ -244,6 +243,7 @@ def forward(
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: Optional[torch.Tensor] = None,
+        **kwargs,  # NOOP kwargs, for now
     ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
         """
         Copied from MptModel.forward: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
@@ -444,6 +444,7 @@ def forward(
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         """
         Inherits from MptForCausalLM: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
@@ -477,14 +478,12 @@
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            batch_size, seq_length, vocab_size = shift_logits.shape
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(
-                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
             )
 
         if not return_dict:
diff --git a/optimum/habana/transformers/models/opt/modeling_opt.py b/optimum/habana/transformers/models/opt/modeling_opt.py
index 3a7c99d96e..0d7afa4de8 100644
--- a/optimum/habana/transformers/models/opt/modeling_opt.py
+++ b/optimum/habana/transformers/models/opt/modeling_opt.py
@@ -1,7 +1,6 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
-from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.opt.configuration_opt import OPTConfig
@@ -496,6 +495,7 @@ def forward(
         return_dict: Optional[bool] = None,
         position_ids: Optional[torch.LongTensor] = None,
         token_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -524,12 +524,12 @@
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/optimum/habana/transformers/models/paligemma/modeling_paligemma.py b/optimum/habana/transformers/models/paligemma/modeling_paligemma.py
index 1d2db48d41..6f2a2817d0 100644
--- a/optimum/habana/transformers/models/paligemma/modeling_paligemma.py
+++ b/optimum/habana/transformers/models/paligemma/modeling_paligemma.py
@@ -48,7 +48,7 @@ def forward(
         return_dict: Optional[bool] = None,
         num_logits_to_keep: int = 0,
         token_idx: Optional[torch.Tensor] = None,
-        **kwargs,
+        **lm_kwargs,
     ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
         """
         Inherits from PaliGemmaForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/models/paligemma/modeling_paligemma.py#L402
@@ -109,7 +109,7 @@
             labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
 
         causal_mask = self._update_causal_mask(
-            attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
         )
         outputs = self.language_model(
             attention_mask=causal_mask,
@@ -124,6 +124,7 @@
             # TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here
             # num_logits_to_keep=num_logits_to_keep,
             token_idx=token_idx,
+            **lm_kwargs,
         )
 
         logits = outputs.logits
diff --git a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py
index 3e56f3c9e2..62fbe16f3c 100644
--- a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py
+++ b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py
@@ -3,7 +3,6 @@
 
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.persimmon.configuration_persimmon import PersimmonConfig
@@ -365,6 +364,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
         token_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         """
         Inherits from PersimmonForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/persimmon/modeling_persimmon.py
@@ -399,16 +399,12 @@
 
         loss = None
         if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py
index 97a78077d7..7457b8f886 100644
--- a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py
+++ b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py
@@ -3,7 +3,6 @@
 
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.stablelm.configuration_stablelm import StableLmConfig
@@ -384,6 +383,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
         token_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         """
         Inherits from StableLmForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/stablelm/modeling_stablelm.py
@@ -416,16 +416,12 @@
 
         loss = None
         if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/optimum/habana/transformers/models/xglm/modeling_xglm.py b/optimum/habana/transformers/models/xglm/modeling_xglm.py
index f69eb3b990..289e0eb55f 100644
--- a/optimum/habana/transformers/models/xglm/modeling_xglm.py
+++ b/optimum/habana/transformers/models/xglm/modeling_xglm.py
@@ -2,7 +2,6 @@
 
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.models.xglm.modeling_xglm import XGLMForCausalLM
 from transformers.utils import logging
@@ -405,6 +404,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         """
         Inherits from XGLMForCausalLM: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py
@@ -440,13 +440,13 @@
 
         loss = None
         if labels is not None:
-            # shift labels and add a pad token to the end
-            shift_labels = labels.new_zeros(labels.shape)
-            shift_labels[:, :-1] = labels[:, 1:].clone()
-            shift_labels[:, -1] = self.config.pad_token_id
-
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                pad_token_id=self.config.pad_token_id,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (logits,) + outputs[1:]