Deprecate modeling_utils.py classes #37298

Merged · 5 commits · Apr 18, 2025
docs/source/en/internal/modeling_utils.md (17 changes: 0 additions & 17 deletions)
@@ -33,23 +33,6 @@ Most of those are only useful if you are studying the code of the models in the

[[autodoc]] pytorch_utils.Conv1D

[[autodoc]] modeling_utils.PoolerStartLogits
- forward

[[autodoc]] modeling_utils.PoolerEndLogits
- forward

[[autodoc]] modeling_utils.PoolerAnswerClass
- forward

[[autodoc]] modeling_utils.SquadHeadOutput

[[autodoc]] modeling_utils.SQuADHead
- forward

[[autodoc]] modeling_utils.SequenceSummary
- forward

## PyTorch Helper Functions

[[autodoc]] pytorch_utils.apply_chunking_to_forward
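
Note for readers migrating off these internals: the same seventeen-line deletion is repeated below for the ja, ko, and zh copies of this page. A minimal migration sketch for user code, assuming you were importing `SequenceSummary` from `transformers.modeling_utils`; the model-specific path is taken from the `# Copied from` annotation in the CLVP diff further down:

```python
# Before: shared head from the transformers internals
# (deprecated, slated for removal in v4.53 per this PR's warnings)
from transformers.modeling_utils import SequenceSummary

# After: the per-model copy, e.g. the XLM variant named in the deprecation messages
from transformers.models.xlm.modeling_xlm import XLMSequenceSummary
```
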
docs/source/ja/internal/modeling_utils.md (17 changes: 0 additions & 17 deletions)
@@ -25,23 +25,6 @@ rendered properly in your Markdown viewer.

[[autodoc]] pytorch_utils.Conv1D

[[autodoc]] modeling_utils.PoolerStartLogits
- forward

[[autodoc]] modeling_utils.PoolerEndLogits
- forward

[[autodoc]] modeling_utils.PoolerAnswerClass
- forward

[[autodoc]] modeling_utils.SquadHeadOutput

[[autodoc]] modeling_utils.SQuADHead
- forward

[[autodoc]] modeling_utils.SequenceSummary
- forward

## PyTorch Helper Functions

[[autodoc]] pytorch_utils.apply_chunking_to_forward
docs/source/ko/internal/modeling_utils.md (17 changes: 0 additions & 17 deletions)
@@ -25,23 +25,6 @@ rendered properly in your Markdown viewer.

[[autodoc]] pytorch_utils.Conv1D

[[autodoc]] modeling_utils.PoolerStartLogits
- forward

[[autodoc]] modeling_utils.PoolerEndLogits
- forward

[[autodoc]] modeling_utils.PoolerAnswerClass
- forward

[[autodoc]] modeling_utils.SquadHeadOutput

[[autodoc]] modeling_utils.SQuADHead
- forward

[[autodoc]] modeling_utils.SequenceSummary
- forward

## PyTorch 헬퍼(helper) 함수 [[transformers.apply_chunking_to_forward]]

[[autodoc]] pytorch_utils.apply_chunking_to_forward
docs/source/zh/internal/modeling_utils.md (17 changes: 0 additions & 17 deletions)
@@ -25,23 +25,6 @@ rendered properly in your Markdown viewer.

[[autodoc]] pytorch_utils.Conv1D

[[autodoc]] modeling_utils.PoolerStartLogits
- forward

[[autodoc]] modeling_utils.PoolerEndLogits
- forward

[[autodoc]] modeling_utils.PoolerAnswerClass
- forward

[[autodoc]] modeling_utils.SquadHeadOutput

[[autodoc]] modeling_utils.SQuADHead
- forward

[[autodoc]] modeling_utils.SequenceSummary
- forward

## PyTorch帮助函数

[[autodoc]] pytorch_utils.apply_chunking_to_forward
src/transformers/modeling_utils.py (28 changes: 28 additions & 0 deletions)
@@ -5384,6 +5384,10 @@ class PoolerStartLogits(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, 1)
logger.warning_once(
"[DEPRECATION WARNING] `PoolerStartLogits` is deprecated and will be removed in v4.53. "
"Please use a model-specific class instead, e.g. `XLMPoolerStartLogits`."
)

def forward(
self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
@@ -5426,6 +5430,10 @@ def __init__(self, config: PretrainedConfig):
self.activation = nn.Tanh()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dense_1 = nn.Linear(config.hidden_size, 1)
logger.warning_once(
"[DEPRECATION WARNING] `PoolerEndLogits` is deprecated and will be removed in v4.53. "
"Please use a model-specific class instead, e.g. `XLMPoolerEndLogits`."
)

def forward(
self,
@@ -5493,6 +5501,10 @@ def __init__(self, config):
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.activation = nn.Tanh()
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
logger.warning_once(
"[DEPRECATION WARNING] `PoolerAnswerClass` is deprecated and will be removed in v4.53. "
"Please use a model-specific class instead, e.g. `XLMPoolerAnswerClass`."
)

def forward(
self,
@@ -5574,6 +5586,12 @@ class SquadHeadOutput(ModelOutput):
end_top_index: Optional[torch.LongTensor] = None
cls_logits: Optional[torch.FloatTensor] = None

def __post_init__(self):
logger.warning_once(
"[DEPRECATION WARNING] `SquadHeadOutput` is deprecated and will be removed in v4.53. "
"Please use a model-specific class instead, e.g. `XLMSquadHeadOutput`."
)


class SQuADHead(nn.Module):
r"""
@@ -5594,6 +5612,11 @@ def __init__(self, config):
self.end_logits = PoolerEndLogits(config)
self.answer_class = PoolerAnswerClass(config)

logger.warning_once(
"[DEPRECATION WARNING] `SQuADHead` is deprecated and will be removed in v4.53. "
"Please use a model-specific class instead, e.g. `XLMSQuADHead`."
)

@replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
def forward(
self,
@@ -5747,6 +5770,11 @@ def __init__(self, config: PretrainedConfig):
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
self.last_dropout = nn.Dropout(config.summary_last_dropout)

logger.warning_once(
"[DEPRECATION WARNING] `SequenceSummary` is deprecated and will be removed in v4.53. "
"Please use a model-specific class instead, e.g. `XLMSequenceSummary`."
)

def forward(
self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
) -> torch.FloatTensor:
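
The pattern throughout this file is to emit the notice from `__init__` (or `__post_init__` for the `SquadHeadOutput` dataclass) via `logger.warning_once`, which transformers implements with an LRU cache so each distinct message is printed once per process rather than on every instantiation. A minimal sketch of the same idea with plain `logging`, using a hypothetical `MyHead` class that is not part of transformers:

```python
import functools
import logging

logger = logging.getLogger(__name__)


@functools.lru_cache(None)
def warning_once(message: str) -> None:
    # Caching on the message string makes repeated calls with the same
    # text no-ops, mirroring transformers' logger.warning_once.
    logger.warning(message)


class MyHead:  # hypothetical stand-in for a deprecated head
    def __init__(self):
        warning_once(
            "[DEPRECATION WARNING] `MyHead` is deprecated and will be "
            "removed in a future release. Please use a model-specific class."
        )
```
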
src/transformers/models/clvp/modeling_clvp.py (108 changes: 104 additions & 4 deletions)
@@ -18,14 +18,14 @@
import copy
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...activations import ACT2FN, get_activation
from ...generation import GenerationConfig, GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
@@ -34,7 +34,7 @@
BaseModelOutputWithPooling,
CausalLMOutputWithCrossAttentions,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import Conv1D, isin_mps_friendly
from ...utils import (
ModelOutput,
@@ -499,6 +499,106 @@ def forward(
return outputs


# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp
class ClvpSequenceSummary(nn.Module):
r"""
Compute a single vector summary of a sequence hidden states.

Args:
config ([`ClvpConfig`]):
The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
config class of your model for the default values it uses):

- **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

- `"last"` -- Take the last token hidden state (like XLNet)
- `"first"` -- Take the first token hidden state (like Bert)
- `"mean"` -- Take the mean of all tokens hidden states
- `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
- `"attn"` -- Not implemented now, use multi-head attention

- **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
- **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
(otherwise to `config.hidden_size`).
- **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
another string or `None` will add no activation.
- **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
- **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
"""

def __init__(self, config: ClvpConfig):
super().__init__()

self.summary_type = getattr(config, "summary_type", "last")
if self.summary_type == "attn":
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
raise NotImplementedError

self.summary = nn.Identity()
if hasattr(config, "summary_use_proj") and config.summary_use_proj:
if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
num_classes = config.num_labels
else:
num_classes = config.hidden_size
self.summary = nn.Linear(config.hidden_size, num_classes)

activation_string = getattr(config, "summary_activation", None)
self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()

self.first_dropout = nn.Identity()
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
self.first_dropout = nn.Dropout(config.summary_first_dropout)

self.last_dropout = nn.Identity()
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
self.last_dropout = nn.Dropout(config.summary_last_dropout)

def forward(
self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
) -> torch.FloatTensor:
"""
Compute a single vector summary of a sequence hidden states.

Args:
hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
The hidden states of the last layer.
cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

Returns:
`torch.FloatTensor`: The summary of the sequence hidden states.
"""
if self.summary_type == "last":
output = hidden_states[:, -1]
elif self.summary_type == "first":
output = hidden_states[:, 0]
elif self.summary_type == "mean":
output = hidden_states.mean(dim=1)
elif self.summary_type == "cls_index":
if cls_index is None:
cls_index = torch.full_like(
hidden_states[..., :1, :],
hidden_states.shape[-2] - 1,
dtype=torch.long,
)
else:
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
elif self.summary_type == "attn":
raise NotImplementedError

output = self.first_dropout(output)
output = self.summary(output)
output = self.activation(output)
output = self.last_dropout(output)

return output
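
As a quick illustration of what this copied class computes, here is a hedged usage sketch; the bare-bones config below is an assumption for demonstration, not a real `ClvpConfig`:

```python
import torch
from types import SimpleNamespace

# Minimal stand-in config; real code would pass a ClvpConfig.
config = SimpleNamespace(
    summary_type="cls_index",
    summary_use_proj=False,
    summary_activation=None,
    summary_first_dropout=0.0,
    summary_last_dropout=0.0,
    hidden_size=8,
)

summary = ClvpSequenceSummary(config)
hidden_states = torch.randn(2, 5, 8)        # (batch, seq_len, hidden_size)
cls_index = torch.tensor([4, 2])            # classification-token position per example
pooled = summary(hidden_states, cls_index)  # gathers hidden_states[i, cls_index[i]]
print(pooled.shape)                         # torch.Size([2, 8])
```
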


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->ClvpDecoderMLP
class ClvpDecoderMLP(nn.Module):
def __init__(self, intermediate_size, config):
@@ -884,7 +984,7 @@ def __init__(self, config: ClvpConfig):
self.rotary_pos_emb = ClvpRotaryPositionalEmbedding(config) if config.use_rotary_embedding else None
self.layers = nn.ModuleList([ClvpEncoderLayer(config) for _ in range(config.num_hidden_layers)])

self.sequence_summary = SequenceSummary(config)
self.sequence_summary = ClvpSequenceSummary(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)