Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from model2vec.distill.inference import PCADimType, PoolingType, create_embeddings, post_process_embeddings
from model2vec.distill.inference import PCADimType, PoolingMode, create_embeddings, post_process_embeddings
from model2vec.distill.utils import select_optimal_device
from model2vec.model import StaticModel
from model2vec.quantization import DType, quantize_embeddings
Expand All @@ -31,7 +31,7 @@ def distill_from_model(
token_remove_pattern: str | None = r"\[unused\d+\]",
quantize_to: DType | str = DType.Float16,
vocabulary_quantization: int | None = None,
pooling: PoolingType = PoolingType.MEAN,
pooling: PoolingMode = PoolingMode.MEAN,
) -> StaticModel:
"""
Distill a staticmodel from a sentence transformer.
Expand All @@ -55,7 +55,7 @@ def distill_from_model(
If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:param pooling: The pooling strategy to use for creating embeddings. Can be one of:
:param pooling: The pooling mode to use for creating embeddings. Can be one of:
'mean' (default): mean over all tokens. Robust and works well in most cases.
'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
'first': use the first token's hidden state ([CLS] token in BERT-style models).
Expand Down Expand Up @@ -209,7 +209,7 @@ def distill(
trust_remote_code: bool = False,
quantize_to: DType | str = DType.Float16,
vocabulary_quantization: int | None = None,
pooling: PoolingType = PoolingType.MEAN,
pooling: PoolingMode = PoolingMode.MEAN,
) -> StaticModel:
"""
Distill a staticmodel from a sentence transformer.
Expand All @@ -232,7 +232,7 @@ def distill(
:param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:param pooling: The pooling strategy to use for creating embeddings. Can be one of:
:param pooling: The pooling mode to use for creating embeddings. Can be one of:
'mean' (default): mean over all tokens. Robust and works well in most cases.
'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
'first': use the first token's hidden state ([CLS] token in BERT-style models).
Expand Down
18 changes: 9 additions & 9 deletions model2vec/distill/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
_DEFAULT_BATCH_SIZE = 256


class PoolingType(str, Enum):
class PoolingMode(str, Enum):
"""
Pooling strategies for embedding creation.
Pooling modes for embedding creation.

- MEAN: masked mean over all tokens.
- LAST: last non-padding token (often EOS, common in decoder-style models).
Expand All @@ -47,7 +47,7 @@ def create_embeddings(
tokenized: list[list[int]],
device: str,
pad_token_id: int,
pooling: PoolingType = PoolingType.MEAN,
pooling: PoolingMode = PoolingMode.MEAN,
) -> np.ndarray:
"""
Create output embeddings for a bunch of tokens using a pretrained model.
Expand All @@ -59,9 +59,9 @@ def create_embeddings(
:param tokenized: All tokenized tokens.
:param device: The torch device to use.
:param pad_token_id: The pad token id. Used to pad sequences.
:param pooling: The pooling strategy to use.
:param pooling: The pooling mode to use.
:return: The output embeddings.
:raises ValueError: If the pooling strategy is unknown.
:raises ValueError: If the pooling mode is unknown.
"""
model = model.to(device).eval() # type: ignore # Transformers error

Expand Down Expand Up @@ -97,13 +97,13 @@ def create_embeddings(
# Add token_type_ids for models that support it
encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])

if pooling == PoolingType.MEAN:
if pooling == PoolingMode.MEAN:
out = _encode_mean_with_model(model, encoded)
elif pooling == PoolingType.LAST:
elif pooling == PoolingMode.LAST:
out = _encode_last_with_model(model, encoded)
elif pooling == PoolingType.FIRST:
elif pooling == PoolingMode.FIRST:
out = _encode_first_with_model(model, encoded)
elif pooling == PoolingType.POOLER:
elif pooling == PoolingMode.POOLER:
out = _encode_pooler_with_model(model, encoded)
else:
raise ValueError(f"Unknown pooling: {pooling}")
Expand Down
12 changes: 6 additions & 6 deletions tests/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from transformers.modeling_utils import PreTrainedModel

from model2vec.distill.distillation import distill, distill_from_model
from model2vec.distill.inference import PoolingType, create_embeddings, post_process_embeddings
from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings
from model2vec.model import StaticModel
from model2vec.tokenizer import clean_and_create_vocabulary

Expand Down Expand Up @@ -260,10 +260,10 @@ def test_clean_and_create_vocabulary(
@pytest.mark.parametrize(
"pooling,with_pooler,expected_rows",
[
(PoolingType.MEAN, False, [1.0, 0.0]), # len=3: mean(0,1,2)=1; len=1: mean(0)=0
(PoolingType.LAST, False, [2.0, 0.0]), # last of 3: 2; last of 1: 0
(PoolingType.FIRST, False, [0.0, 0.0]), # first position: 0
(PoolingType.POOLER, True, [7.0, 7.0]), # pooler_output used
(PoolingMode.MEAN, False, [1.0, 0.0]), # len=3: mean(0,1,2)=1; len=1: mean(0)=0
(PoolingMode.LAST, False, [2.0, 0.0]), # last of 3: 2; last of 1: 0
(PoolingMode.FIRST, False, [0.0, 0.0]), # first position: 0
(PoolingMode.POOLER, True, [7.0, 7.0]), # pooler_output used
],
)
def test_pooling_strategies(mock_transformer, pooling, with_pooler, expected_rows) -> None:
Expand Down Expand Up @@ -292,5 +292,5 @@ def test_pooler_raises_without_pooler_output(mock_transformer) -> None:
tokenized=tokenized,
device="cpu",
pad_token_id=0,
pooling=PoolingType.POOLER,
pooling=PoolingMode.POOLER,
)