24 changes: 21 additions & 3 deletions model2vec/distill/distillation.py
@@ -11,7 +11,7 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
from model2vec.distill.inference import PCADimType, PoolingType, create_embeddings, post_process_embeddings
from model2vec.distill.utils import select_optimal_device
from model2vec.model import StaticModel
from model2vec.quantization import DType, quantize_embeddings
@@ -33,6 +33,7 @@ def distill_from_model(
quantize_to: DType | str = DType.Float16,
use_subword: bool | None = None,
vocabulary_quantization: int | None = None,
pooling: PoolingType = PoolingType.MEAN,
) -> StaticModel:
"""
Distill a StaticModel from a sentence transformer.
@@ -59,7 +60,12 @@
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:return: A StaticModel
:param pooling: The pooling strategy to use for creating embeddings. Can be one of:
'mean' (default): mean over all tokens. Robust and works well in most cases.
'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
'first': use the first token's hidden state ([CLS] token in BERT-style models).
'pooler': use the pooler output (if available). This is often a non-linear projection of the [CLS] token.
:return: A StaticModel.
:raises: ValueError if the vocabulary is empty after preprocessing.

"""
@@ -114,7 +120,11 @@

# Create the embeddings
embeddings = create_embeddings(
tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
tokenized=token_ids,
model=model,
device=device,
pad_token_id=tokenizer.get_vocab()[pad_token],
pooling=pooling,
)

if vocabulary_quantization is not None:
@@ -142,6 +152,7 @@
"hidden_dim": embeddings.shape[1],
"seq_length": 1000000, # Set this to a high value since we don't have a sequence length limit.
"normalize": True,
"pooling": pooling,
}

if os.path.exists(model_name):
Expand Down Expand Up @@ -226,6 +237,7 @@ def distill(
quantize_to: DType | str = DType.Float16,
use_subword: bool | None = None,
vocabulary_quantization: int | None = None,
pooling: PoolingType = PoolingType.MEAN,
) -> StaticModel:
"""
Distill a StaticModel from a sentence transformer.
@@ -251,6 +263,11 @@
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:param pooling: The pooling strategy to use for creating embeddings. Can be one of:
'mean' (default): mean over all tokens. Robust and works well in most cases.
'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
'first': use the first token's hidden state ([CLS] token in BERT-style models).
'pooler': use the pooler output (if available). This is often a non-linear projection of the [CLS] token.
:return: A StaticModel

"""
@@ -272,4 +289,5 @@ def distill(
quantize_to=quantize_to,
use_subword=use_subword,
vocabulary_quantization=vocabulary_quantization,
pooling=pooling,
)
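
A minimal usage sketch of the new `pooling` argument (not part of the diff). It assumes the public `distill` entry point keeps its `model_name` keyword; the checkpoint name is illustrative only.

```python
# Hedged sketch: exercises the `pooling` parameter added in this PR.
from model2vec.distill import distill
from model2vec.distill.inference import PoolingType

# Default behaviour is unchanged: masked mean over all tokens.
m2v_mean = distill(model_name="BAAI/bge-base-en-v1.5")

# Decoder-style models often pair better with last-token pooling.
m2v_last = distill(model_name="BAAI/bge-base-en-v1.5", pooling=PoolingType.LAST)

# BERT-style encoders can instead use the [CLS] position or the pooler head.
m2v_cls = distill(model_name="BAAI/bge-base-en-v1.5", pooling=PoolingType.FIRST)
```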
146 changes: 120 additions & 26 deletions model2vec/distill/inference.py
@@ -3,8 +3,9 @@

import inspect
import logging
from enum import Enum
from pathlib import Path
from typing import Literal, Protocol, Union
from typing import Literal, Union

import numpy as np
import torch
@@ -16,23 +17,37 @@

logger = logging.getLogger(__name__)


PathLike = Union[Path, str]
PCADimType = Union[int, None, float, Literal["auto"]]


_DEFAULT_BATCH_SIZE = 256


class ModulewithWeights(Protocol):
weight: torch.nn.Parameter
class PoolingType(str, Enum):
"""
Pooling strategies for embedding creation.

- MEAN: masked mean over all tokens.
- LAST: last non-padding token (often EOS, common in decoder-style models).
- FIRST: first token hidden state (position 0). In BERT-style encoders,
this corresponds to the [CLS] token representation.
- POOLER: use the model's `pooler_output`. In BERT-like models this is
computed as the hidden state at [CLS], passed through a learned
dense layer + activation. Not all models provide this.
"""

MEAN = "mean"
LAST = "last"
FIRST = "first"
POOLER = "pooler"


def create_embeddings(
model: PreTrainedModel,
tokenized: list[list[int]],
device: str,
pad_token_id: int,
pooling: PoolingType = PoolingType.MEAN,
) -> np.ndarray:
"""
Create output embeddings for a bunch of tokens using a pretrained model.
@@ -44,9 +59,11 @@ def create_embeddings(
:param tokenized: All tokenized tokens.
:param device: The torch device to use.
:param pad_token_id: The pad token id. Used to pad sequences.
:param pooling: The pooling strategy to use.
:return: The output embeddings.
:raises ValueError: If the pooling strategy is unknown.
"""
model = model.to(device) # type: ignore # Transformers error
model = model.to(device).eval() # type: ignore # Transformers error

out_weights: np.ndarray
intermediate_weights: list[np.ndarray] = []
@@ -62,56 +79,133 @@
pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens")

for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
batch = [torch.Tensor(x).long() for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]]
batch_list = sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
batch = [torch.tensor(x, dtype=torch.long) for x in batch_list]

encoded = {}
encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
encoded["attention_mask"] = encoded["input_ids"] != pad_token_id

# Create the attention mask from the length of each sequence
seq_len = encoded["input_ids"].size(1)
batch_lengths = torch.tensor([len(x) for x in batch_list], device=encoded["input_ids"].device)
token_positions = torch.arange(seq_len, device=encoded["input_ids"].device)
# Mark padding tokens with 0, and non-padding tokens with 1
attention_mask = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
encoded["attention_mask"] = attention_mask.to(dtype=torch.long)

if add_token_type_ids:
# Add token_type_ids for models that support it
encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])

out = _encode_mean_using_model(model, encoded)
if pooling == PoolingType.MEAN:
out = _encode_mean_with_model(model, encoded)
elif pooling == PoolingType.LAST:
out = _encode_last_with_model(model, encoded)
elif pooling == PoolingType.FIRST:
out = _encode_first_with_model(model, encoded)
elif pooling == PoolingType.POOLER:
out = _encode_pooler_with_model(model, encoded)
else:
raise ValueError(f"Unknown pooling: {pooling}")

intermediate_weights.extend(out.numpy())
pbar.update(len(batch))

# Sort the output back to the original order
intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
out_weights = np.stack(intermediate_weights)

out_weights = np.nan_to_num(out_weights)

return out_weights


@torch.no_grad()
def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
def _encode_with_model(
model: PreTrainedModel, encodings: dict[str, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor]]:
"""
Encode a batch of tokens using a model.

Note that if a token in the input batch does not have any embeddings, it will be output as a vector of zeros.
So detection of these is necessary.
Move inputs to the model device, run a forward pass, and standardize dtypes.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
:return: The mean of the output for each token.
:return: a tuple consisting of:
- hidden: last_hidden_state
- pooler: pooler_output if present, else None
- encodings_on_device: the device-moved encodings (for masks)
"""
encodings = {k: v.to(model.device) for k, v in encodings.items()}
encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
out: torch.Tensor = encoded.last_hidden_state.cpu() # type: ignore # False positive
encodings_on_device = {k: v.to(model.device) for k, v in encodings.items()}
outputs: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings_on_device)
hidden: torch.Tensor = outputs.last_hidden_state # type: ignore # False positive
# NOTE: If the dtype is bfloat16, we convert to float32,
# because numpy does not support bfloat16
# See here: https://github.com/numpy/numpy/issues/19808
if out.dtype == torch.bfloat16:
out = out.float()
if hidden.dtype == torch.bfloat16:
hidden = hidden.float()
pooler = getattr(outputs, "pooler_output", None)
if pooler is not None and pooler.dtype == torch.bfloat16:
pooler = pooler.float()
return hidden, pooler, encodings_on_device


@torch.inference_mode()
def _encode_mean_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Encode a batch of tokens using mean pooling.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
:return: The mean of the output for each token.
"""
hidden, _, encodings_on_device = _encode_with_model(model, encodings)
# Take a masked mean: normalize the attention mask and use it as per-token weights.
mask = encodings["attention_mask"].cpu().float()
mask /= mask.sum(1)[:, None]
mask = encodings_on_device["attention_mask"].cpu().float()
lengths = mask.sum(1, keepdim=True).clamp_min_(1.0)
mask = mask / lengths
return torch.bmm(mask.to(hidden.device)[:, None, :], hidden).squeeze(1).cpu()


result = torch.bmm(mask[:, None, :].float(), out).squeeze(1)
@torch.inference_mode()
def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Encode a batch of tokens using last token pooling.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
:return: The last non-padding token's hidden state for each sequence.
"""
hidden, _, encodings_on_device = _encode_with_model(model, encodings)
mask = encodings_on_device["attention_mask"].bool()
last_idx = (mask.sum(dim=1) - 1).clamp_min(0).long()
batch_indices = torch.arange(hidden.size(0), device=hidden.device)
return hidden[batch_indices, last_idx, :].cpu()


@torch.inference_mode()
def _encode_first_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Encode a batch of tokens using first token (CLS) pooling.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
:return: The first token's hidden state for each sequence.
"""
hidden, _, _ = _encode_with_model(model, encodings)
return hidden[:, 0, :].cpu()

return result

@torch.inference_mode()
def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Encode a batch of tokens using pooler output.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
:return: The pooler output for each sequence.
:raises ValueError: If the model does not return pooler_output.
"""
_, pooler, _ = _encode_with_model(model, encodings)
if pooler is None:
raise ValueError("POOLER pooling requested, but model did not return pooler_output.")
return pooler.cpu()


def post_process_embeddings(
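
The two non-trivial pooling paths above (masked mean via `bmm`, last non-padding token via indexing) can be checked in isolation. This is a standalone sketch on made-up toy tensors, not code from the PR.

```python
# Toy reproduction of the pooling arithmetic used in inference.py (illustrative only).
import torch

hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)  # (B, T, H)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])             # (B, T), second row is padded

# Masked mean: normalize each mask row to sum to 1, then (B, 1, T) @ (B, T, H) -> (B, 1, H).
weights = attention_mask.float()
weights = weights / weights.sum(1, keepdim=True).clamp_min(1.0)
mean_pooled = torch.bmm(weights[:, None, :], hidden).squeeze(1)

# Last-token pooling: index the last non-padding position in each row.
last_idx = (attention_mask.sum(dim=1) - 1).clamp_min(0)
last_pooled = hidden[torch.arange(hidden.size(0)), last_idx, :]

print(mean_pooled.shape, last_pooled.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```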
31 changes: 16 additions & 15 deletions tests/conftest.py
@@ -59,29 +59,30 @@ def mock_transformer() -> PreTrainedModel:
"""Create a mock transformer model."""

class MockPreTrainedModel:
def __init__(self) -> None:
def __init__(self, dim: int = 768, with_pooler: bool = True, pooler_value: float = 7.0) -> None:
self.device = "cpu"
self.name_or_path = "mock-model"
self.dim = dim
self.with_pooler = with_pooler
self.pooler_value = pooler_value

def to(self, device: str) -> MockPreTrainedModel:
self.device = device
return self

def eval(self) -> MockPreTrainedModel:
return self

def forward(self, *args: Any, **kwargs: Any) -> Any:
# Simulate a last_hidden_state output for a transformer model
batch_size, seq_length = kwargs["input_ids"].shape
# Return a tensor of shape (batch_size, seq_length, 768)
return type(
"BaseModelOutputWithPoolingAndCrossAttentions",
(object,),
{
"last_hidden_state": torch.rand(batch_size, seq_length, 768) # Simulate 768 hidden units
},
)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Simply call the forward method to simulate the same behavior as transformers models
return self.forward(*args, **kwargs)
input_ids = kwargs["input_ids"]
B, T = input_ids.shape
hidden = torch.arange(T, dtype=torch.float32, device=self.device).repeat(B, self.dim, 1).transpose(1, 2)
out = {"last_hidden_state": hidden}
if self.with_pooler:
out["pooler_output"] = torch.full((B, self.dim), self.pooler_value, device=self.device)
return type("BaseModelOutputWithPoolingAndCrossAttentions", (object,), out)()

__call__ = forward

return cast(PreTrainedModel, MockPreTrainedModel())
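
Because the mock's hidden states are just the position index repeated across the hidden dimension, each pooling strategy has a closed-form expectation. A hedged test sketch (not part of the PR; the test name and token ids are illustrative) could look like:

```python
# Hypothetical test sketch building on the mock_transformer fixture above.
import numpy as np
from model2vec.distill.inference import PoolingType, create_embeddings


def test_pooling_strategies_on_mock(mock_transformer) -> None:
    token_ids = [[5, 6, 7], [8, 9]]  # lengths 3 and 2; the mock yields hidden[b, t, :] == t
    expected_per_row = {
        PoolingType.MEAN: [1.0, 0.5],    # mean of positions 0..len-1
        PoolingType.LAST: [2.0, 1.0],    # position len-1
        PoolingType.FIRST: [0.0, 0.0],   # position 0
        PoolingType.POOLER: [7.0, 7.0],  # the mock's constant pooler_value
    }
    for pooling, per_row in expected_per_row.items():
        emb = create_embeddings(
            model=mock_transformer,
            tokenized=token_ids,
            device="cpu",
            pad_token_id=0,
            pooling=pooling,
        )
        expected = np.broadcast_to(np.array(per_row)[:, None], emb.shape)
        np.testing.assert_allclose(emb, expected)
```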
