35 changes: 1 addition & 34 deletions model2vec/distill/distillation.py
@@ -27,11 +27,9 @@ def distill_from_model(
vocabulary: list[str] | None = None,
device: str | None = None,
pca_dims: PCADimType = 256,
apply_zipf: bool | None = None,
sif_coefficient: float | None = 1e-4,
token_remove_pattern: str | None = r"\[unused\d+\]",
quantize_to: DType | str = DType.Float16,
use_subword: bool | None = None,
vocabulary_quantization: int | None = None,
pooling: PoolingType = PoolingType.MEAN,
) -> StaticModel:
@@ -51,14 +49,11 @@ def distill_from_model(
:param pca_dims: The number of components to use for PCA.
If this is None, we don't apply PCA.
If this is 'auto', we don't reduce dimensionality, but still apply PCA.
:param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:param pooling: The pooling strategy to use for creating embeddings. Can be one of:
'mean' (default): mean over all tokens. Robust and works well in most cases.
Expand All @@ -69,13 +64,9 @@ def distill_from_model(
:raises: ValueError if the vocabulary is empty after preprocessing.

"""
if use_subword is not None:
logger.warning(
"The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything."
)
quantize_to = DType(quantize_to)
backend_tokenizer = tokenizer.backend_tokenizer
sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern)
sif_coefficient, token_remove_regex = _validate_parameters(sif_coefficient, token_remove_pattern)

if vocabulary is None:
vocabulary = []
@@ -147,7 +138,6 @@ def distill_from_model(
"architectures": ["StaticModel"],
"tokenizer_name": model_name,
"apply_pca": pca_dims,
"apply_zipf": apply_zipf,
"sif_coefficient": sif_coefficient,
"hidden_dim": embeddings.shape[1],
"seq_length": 1000000, # Set this to a high value since we don't have a sequence length limit.
@@ -182,35 +172,19 @@


def _validate_parameters(
apply_zipf: bool | None,
sif_coefficient: float | None,
token_remove_pattern: str | None,
) -> tuple[float | None, re.Pattern | None]:
"""
Validate the parameters passed to the distillation function.

:param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
:return: The validated SIF coefficient and the compiled token removal regex (either may be None).
:raises: ValueError if the regex can't be compiled or the SIF coefficient is out of range.

"""
if apply_zipf is not None:
logger.warning(
"The `apply_zipf` parameter is deprecated and will be removed in the next release. "
"Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, "
"no weighting is applied."
)
if apply_zipf and sif_coefficient is None:
logger.warning("You set apply_zipf to True, but sif_coefficient is None. Setting sif_coefficient to 1e-4.")
sif_coefficient = 1e-4
elif not apply_zipf:
logger.warning("Because you set apply_zipf to False, we ignore the sif_coefficient parameter.")
sif_coefficient = None

if sif_coefficient is not None:
if not 0 < sif_coefficient < 1.0:
raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")
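
SIF/Zipf weighting is controlled by `sif_coefficient` alone: each row of the embedding matrix is scaled by a / (a + p(token)), with p(token) estimated from a Zipf-style inverse-rank distribution, which is exactly what `test__post_process_embeddings` below verifies. A minimal standalone sketch of that weighting (the helper name `sif_weights` is illustrative, not a library function):

```python
import numpy as np


def sif_weights(n_tokens: int, sif_coefficient: float = 1e-4) -> np.ndarray:
    """Compute SIF weights a / (a + p(token)) for tokens sorted by frequency rank."""
    if not 0 < sif_coefficient < 1.0:
        raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")
    inv_rank = 1 / np.arange(2, n_tokens + 2)   # Zipf-like inverse rank estimate
    proba = inv_rank / inv_rank.sum()           # normalize to a probability distribution
    return (sif_coefficient / (sif_coefficient + proba))[:, None]


# Applying the weights scales each token embedding row:
# weighted = embeddings * sif_weights(embeddings.shape[0], 1e-4)
```
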
@@ -230,12 +204,10 @@ def distill(
vocabulary: list[str] | None = None,
device: str | None = None,
pca_dims: PCADimType = 256,
apply_zipf: bool | None = None,
sif_coefficient: float | None = 1e-4,
token_remove_pattern: str | None = r"\[unused\d+\]",
trust_remote_code: bool = False,
quantize_to: DType | str = DType.Float16,
use_subword: bool | None = None,
vocabulary_quantization: int | None = None,
pooling: PoolingType = PoolingType.MEAN,
) -> StaticModel:
@@ -254,14 +226,11 @@ def distill(
:param pca_dims: The number of components to use for PCA.
If this is None, we don't apply PCA.
If this is 'auto', we don't reduce dimensionality, but still apply PCA.
:param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
:param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:param pooling: The pooling strategy to use for creating embeddings. Can be one of:
'mean' (default): mean over all tokens. Robust and works well in most cases.
@@ -283,11 +252,9 @@ def distill(
vocabulary=vocabulary,
device=device,
pca_dims=pca_dims,
apply_zipf=apply_zipf,
token_remove_pattern=token_remove_pattern,
sif_coefficient=sif_coefficient,
quantize_to=quantize_to,
use_subword=use_subword,
vocabulary_quantization=vocabulary_quantization,
pooling=pooling,
)
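
A hedged usage sketch of the distillation entry point with the simplified keyword arguments; the checkpoint name is an arbitrary example, and every parameter shown comes from the signature above:

```python
from model2vec.distill.distillation import distill
from model2vec.distill.inference import PoolingType

# Example checkpoint name only; substitute the transformer you want to distill.
static_model = distill(
    model_name="BAAI/bge-base-en-v1.5",
    pca_dims=256,                            # None skips PCA, "auto" keeps all dimensions
    sif_coefficient=1e-4,                    # None disables SIF/Zipf weighting entirely
    token_remove_pattern=r"\[unused\d+\]",   # default: drop BERT-style [unusedN] tokens
    pooling=PoolingType.MEAN,                # mean pooling over all tokens
)
print(static_model.embedding.shape)
```
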
112 changes: 39 additions & 73 deletions tests/test_distillation.py
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import json
@@ -10,14 +11,10 @@
from transformers import BertTokenizerFast
from transformers.modeling_utils import PreTrainedModel

from model2vec.distill.distillation import (
clean_and_create_vocabulary,
distill,
distill_from_model,
post_process_embeddings,
)
from model2vec.distill.inference import PoolingType, create_embeddings
from model2vec.distill.distillation import distill, distill_from_model
from model2vec.distill.inference import PoolingType, create_embeddings, post_process_embeddings
from model2vec.model import StaticModel
from model2vec.tokenizer import clean_and_create_vocabulary

try:
# For huggingface_hub>=0.25.0
@@ -30,14 +27,14 @@


@pytest.mark.parametrize(
"vocabulary, pca_dims, apply_zipf",
"vocabulary, pca_dims, sif_coefficient",
[
(None, 256, True), # Output vocab with subwords, PCA applied
(["wordA", "wordB"], 4, False), # Custom vocab with subword, PCA applied
(None, "auto", False), # Subword, PCA set to 'auto'
(None, 1024, False), # Subword, PCA set to high number.
(None, None, True), # No PCA applied
(None, 0.9, True), # PCA as float applied
(None, 256, 1e-4), # Subword vocab, PCA applied, SIF on
(["wordA", "wordB"], 4, None), # Custom vocab, PCA applied, SIF off
(None, "auto", None), # Subword, PCA 'auto', SIF off
(None, 1024, None), # Subword, PCA set high, SIF off
(None, None, 1e-4), # No PCA, SIF on
(None, 0.9, 1e-4), # PCA as float (variance), SIF on
],
)
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@@ -49,24 +46,20 @@ def test_distill_from_model(
mock_transformer: PreTrainedModel,
vocabulary: list[str] | None,
pca_dims: int | None,
apply_zipf: bool,
sif_coefficient: float | None,
) -> None:
"""Test distill function with different parameters."""
# Mock the return value of model_info to avoid calling the Hugging Face API
mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})

# Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
# mock_auto_tokenizer.return_value = mock_berttokenizer
mock_auto_model.return_value = mock_transformer

# Call the distill function with the parametrized inputs
static_model = distill_from_model(
model=mock_transformer,
tokenizer=mock_berttokenizer,
vocabulary=vocabulary,
device="cpu",
pca_dims=pca_dims,
apply_zipf=apply_zipf,
sif_coefficient=sif_coefficient,
token_remove_pattern=None,
)

@@ -75,7 +68,7 @@
vocabulary=vocabulary,
device="cpu",
pca_dims=pca_dims,
apply_zipf=apply_zipf,
sif_coefficient=sif_coefficient,
token_remove_pattern=None,
)

@@ -94,11 +87,7 @@ def test_distill_removal_pattern(
mock_transformer: PreTrainedModel,
) -> None:
"""Test the removal pattern."""
# Mock the return value of model_info to avoid calling the Hugging Face API
mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})

# Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
# mock_auto_tokenizer.return_value = mock_berttokenizer
mock_auto_model.return_value = mock_transformer

# The vocab size is 30522, but we remove 998 tokens: [CLS], [SEP], and [MASK], and all [unused] tokens.
@@ -111,7 +100,6 @@
device="cpu",
token_remove_pattern=None,
)

assert len(static_model.embedding) == expected_vocab_size

# No tokens removed, nonsensical pattern
@@ -122,12 +110,11 @@
device="cpu",
token_remove_pattern="£££££££££££££££££",
)

assert len(static_model.embedding) == expected_vocab_size

# Weird pattern.
with pytest.raises(ValueError):
static_model = distill_from_model(
_ = distill_from_model(
model=mock_transformer,
tokenizer=mock_berttokenizer,
vocabulary=None,
@@ -137,19 +124,16 @@


@pytest.mark.parametrize(
"vocabulary, pca_dims, apply_zipf, sif_coefficient, expected_shape",
"vocabulary, pca_dims, sif_coefficient, expected_shape",
[
(None, 256, True, None, (29524, 256)), # Output vocab with subwords, PCA applied
(None, "auto", False, None, (29524, 768)), # Subword, PCA set to 'auto'
(None, "auto", True, 1e-4, (29524, 768)), # Subword, PCA set to 'auto'
(None, "auto", False, 1e-4, (29524, 768)), # Subword, PCA set to 'auto'
(None, "auto", True, 0, None), # Sif too low
(None, "auto", True, 1, None), # Sif too high
(None, "auto", False, 0, (29524, 768)), # Sif too low, but apply_zipf is False
(None, "auto", False, 1, (29524, 768)), # Sif too high, but apply_zipf is False
(None, 1024, False, None, (29524, 768)), # Subword, PCA set to high number.
(["wordA", "wordB"], 4, False, None, (29526, 4)), # Custom vocab with subword, PCA applied
(None, None, True, None, (29524, 768)), # No PCA applied
(None, 256, None, (29524, 256)), # PCA applied, SIF off
(None, "auto", None, (29524, 768)), # PCA 'auto', SIF off
(None, "auto", 1e-4, (29524, 768)), # PCA 'auto', SIF on
(None, "auto", 0, None), # invalid SIF (too low) -> raises
(None, "auto", 1, None), # invalid SIF (too high) -> raises
(None, 1024, None, (29524, 768)), # PCA set high (no reduction)
(["wordA", "wordB"], 4, None, (29526, 4)), # Custom vocab, PCA applied
(None, None, None, (29524, 768)), # No PCA, SIF off
],
)
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@@ -160,47 +144,32 @@ def test_distill(
mock_transformer: PreTrainedModel,
vocabulary: list[str] | None,
pca_dims: int | None,
apply_zipf: bool,
sif_coefficient: float | None,
expected_shape: tuple[int, int],
expected_shape: tuple[int, int] | None,
) -> None:
"""Test distill function with different parameters."""
# Mock the return value of model_info to avoid calling the Hugging Face API
mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})

# Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
mock_auto_model.return_value = mock_transformer

model_name = "tests/data/test_tokenizer"

if (
apply_zipf is not None
and apply_zipf
and sif_coefficient is not None
and (sif_coefficient <= 0 or sif_coefficient >= 1)
):
if sif_coefficient is not None and (sif_coefficient <= 0 or sif_coefficient >= 1):
with pytest.raises(ValueError):
static_model = distill(
_ = distill(
model_name=model_name,
vocabulary=vocabulary,
device="cpu",
pca_dims=pca_dims,
apply_zipf=apply_zipf,
sif_coefficient=sif_coefficient,
)

else:
# Call the distill function with the parametrized inputs
static_model = distill(
model_name=model_name,
vocabulary=vocabulary,
device="cpu",
pca_dims=pca_dims,
apply_zipf=apply_zipf,
sif_coefficient=sif_coefficient,
)

# Assert the model is correctly generated
assert isinstance(static_model, StaticModel)
assert static_model.embedding.shape == expected_shape
assert "mock-model" in static_model.config["tokenizer_name"]
@@ -223,37 +192,36 @@ def test_missing_modelinfo(
"embeddings, pca_dims, sif_coefficient, expected_shape",
[
(rng.random((1000, 768)), 256, None, (1000, 256)), # PCA applied correctly
(rng.random((1000, 768)), None, None, (1000, 768)), # No PCA applied, dimensions remain unchanged
(rng.random((1000, 768)), 256, 1e-4, (1000, 256)), # PCA and Zipf applied
(rng.random((10, 768)), 256, 1e-4, (10, 768)), # PCA dims higher than vocab size, no PCA applied
(rng.random((1000, 768)), None, None, (1000, 768)), # No PCA applied, dimensions unchanged
(rng.random((1000, 768)), 256, 1e-4, (1000, 256)), # PCA and SIF applied
(rng.random((10, 768)), 256, 1e-4, (10, 768)), # PCA dims > vocab size, no PCA applied
],
)
def test__post_process_embeddings(
embeddings: np.ndarray, pca_dims: int, sif_coefficient: float | None, expected_shape: tuple[int, int]
embeddings: np.ndarray, pca_dims: int | float | None, sif_coefficient: float | None, expected_shape: tuple[int, int]
) -> None:
"""Test the _post_process_embeddings function."""
"""Test the post_process_embeddings function."""
original_embeddings = embeddings.copy() # Copy embeddings to compare later

# Test that the function raises an error if the PCA dims are larger than the number of dimensions
if pca_dims and pca_dims > embeddings.shape[1]:
with pytest.raises(ValueError):
post_process_embeddings(embeddings, pca_dims, None)
# If pca_dims is larger than the embedding dimensionality, the implementation
# logs a warning and skips the reduction, so no error handling is needed here.

processed_embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient)

# Assert the shape is correct
assert processed_embeddings.shape == expected_shape

# If Zipf weighting is applied compare the original and processed embeddings
# and check the weights are applied correctly
# If SIF weighting is applied and no PCA reduction, check weights are applied correctly
if sif_coefficient and pca_dims is None:
inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
proba = inv_rank / np.sum(inv_rank)
sif_weights = (sif_coefficient / (sif_coefficient + proba))[:, None]

expected_zipf_embeddings = original_embeddings * sif_weights
assert np.allclose(processed_embeddings, expected_zipf_embeddings, rtol=1e-5), (
"Zipf weighting not applied correctly"
"SIF weighting not applied correctly"
)
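
Since `post_process_embeddings` is imported from `model2vec.distill.inference` here, the PCA-plus-SIF step can also be exercised directly. A small sketch assuming the positional call order used by the test above (`embeddings, pca_dims, sif_coefficient`); the second return value is ignored, as the test does:

```python
import numpy as np

from model2vec.distill.inference import post_process_embeddings

rng = np.random.default_rng(0)
embeddings = rng.random((1000, 768))

# Reduce to 256 dimensions with PCA and apply SIF weighting with a = 1e-4.
processed, _ = post_process_embeddings(embeddings, 256, 1e-4)
assert processed.shape == (1000, 256)  # matches the parametrized expectation above
```
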


@@ -275,7 +243,7 @@ def test_clean_and_create_vocabulary(
expected_warnings: list[str],
caplog: LogCaptureFixture,
) -> None:
"""Test the _clean_vocabulary function."""
"""Test the clean_and_create_vocabulary helper."""
with caplog.at_level("WARNING"):
tokens, _ = clean_and_create_vocabulary(mock_berttokenizer, added_tokens, None)

@@ -285,8 +253,6 @@

# Check the warnings were logged as expected
logged_warnings = [record.message for record in caplog.records]

# Ensure the expected warnings contain expected keywords like 'Removed', 'duplicate', or 'empty'
for expected_warning in expected_warnings:
assert any(expected_warning in logged_warning for logged_warning in logged_warnings)
