95 changes: 33 additions & 62 deletions model2vec/distill/distillation.py
@@ -3,10 +3,10 @@
import logging
import os
import re
from typing import Optional, cast
from typing import cast

import numpy as np
from huggingface_hub.hf_api import model_info
from huggingface_hub import model_info
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
@@ -16,6 +16,7 @@
from model2vec.model import StaticModel
from model2vec.quantization import DType, quantize_embeddings
from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
from model2vec.tokenizer.tokenizer import _patch_tokenizer
from model2vec.vocabulary_quantization import quantize_vocabulary

logger = logging.getLogger(__name__)
@@ -27,11 +28,10 @@ def distill_from_model(
vocabulary: list[str] | None = None,
device: str | None = None,
pca_dims: PCADimType = 256,
apply_zipf: bool | None = None,
sif_coefficient: float | None = 1e-4,
token_remove_pattern: str | None = r"\[unused\d+\]",
quantize_to: DType | str = DType.Float16,
use_subword: bool | None = None,
lower_case: bool = True,
vocabulary_quantization: int | None = None,
) -> StaticModel:
"""
@@ -50,26 +50,21 @@
:param pca_dims: The number of components to use for PCA.
If this is None, we don't apply PCA.
If this is 'auto', we don't reduce dimensionality, but still apply PCA.
:param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
:param lower_case: If this is set, all tokens in the model vocabulary will be converted to lowercase, and
a lowercase normalizer will be inserted. This almost always improves performance.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:return: A StaticModel
:raises: ValueError if the vocabulary is empty after preprocessing.

"""
if use_subword is not None:
logger.warning(
"The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything."
)
quantize_to = DType(quantize_to)
backend_tokenizer = tokenizer.backend_tokenizer
sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern)
sif_coefficient, token_remove_regex = _validate_parameters(sif_coefficient, token_remove_pattern)

if vocabulary is None:
vocabulary = []
@@ -78,44 +73,40 @@

n_tokens_before = len(vocabulary)
# Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
all_tokens, backend_tokenizer = clean_and_create_vocabulary(
tokenizer, vocabulary, token_remove_regex=token_remove_regex
tokens, backend_tokenizer = clean_and_create_vocabulary(
tokenizer, vocabulary, token_remove_regex=token_remove_regex, lower_case=lower_case
)
n_tokens_after = len([token for token in all_tokens if not token.is_internal])
n_tokens_after = len([token for token in tokens if not token.is_internal])
if n_tokens_before:
logger.info(
f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
)

if not all_tokens:
if not tokens:
raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.")

unk_token = cast(Optional[str], tokenizer.special_tokens_map.get("unk_token"))
pad_token = cast(Optional[str], tokenizer.special_tokens_map.get("pad_token"))

# Weird if to satisfy mypy
if pad_token is None:
if unk_token is not None:
pad_token = unk_token
logger.warning(
"The pad token is not set. Setting it to the unk token. This is a workaround for models that don't have a pad token."
)
else:
pad_token = unk_token or all_tokens[0].form
logger.warning(
"The pad token is not set. Setting it to the first token in the vocabulary. This is a workaround for models that don't have a pad token."
)

# Replace the vocabulary in the tokenizer with the new vocabulary.
backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
logger.info(f"Creating embeddings for {len(all_tokens)} tokens")
backend_tokenizer = replace_vocabulary(backend_tokenizer, tokens)

logger.info(f"Creating embeddings for {len(tokens)} tokens")
# Convert tokens to IDs
token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)
m = _patch_tokenizer(tokenizer=tokenizer, lower_case=False)
bb = m.to_tokenizer()

token_ids = turn_tokens_into_ids(tokens, bb)

# Create the embeddings
embeddings = create_embeddings(
tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
)
pad_token = cast(str | None, tokenizer.special_tokens_map.get("pad_token", None))
vocab = tokenizer.get_vocab()
if pad_token is None:
sep_token = cast(str | None, tokenizer.special_tokens_map.get("sep_token", None))
if sep_token is None:
pad_token_id = 0
else:
pad_token_id = vocab[sep_token]
else:
pad_token_id = vocab[pad_token]
embeddings = create_embeddings(tokenized=token_ids, model=model, device=device, pad_token_id=pad_token_id)

if vocabulary_quantization is not None:
_, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient)
@@ -137,7 +128,6 @@ def distill_from_model(
"architectures": ["StaticModel"],
"tokenizer_name": model_name,
"apply_pca": pca_dims,
"apply_zipf": apply_zipf,
"sif_coefficient": sif_coefficient,
"hidden_dim": embeddings.shape[1],
"seq_length": 1000000, # Set this to a high value since we don't have a sequence length limit.
@@ -171,35 +161,19 @@ def distill_from_model(


def _validate_parameters(
apply_zipf: bool | None,
sif_coefficient: float | None,
token_remove_pattern: str | None,
) -> tuple[float | None, re.Pattern | None]:
"""
Validate the parameters passed to the distillation function.

:param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
:return: The SIF coefficient to use.
:raises: ValueError if the regex can't be compiled.

"""
if apply_zipf is not None:
logger.warning(
"The `apply_zipf` parameter is deprecated and will be removed in the next release. "
"Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, "
"no weighting is applied."
)
if apply_zipf and sif_coefficient is None:
logger.warning("You set apply_zipf to True, but sif_coefficient is None. Setting sif_coefficient to 1e-4.")
sif_coefficient = 1e-4
elif not apply_zipf:
logger.warning("Because you set apply_zipf to False, we ignore the sif_coefficient parameter.")
sif_coefficient = None

if sif_coefficient is not None:
if not 0 < sif_coefficient < 1.0:
raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")
@@ -219,12 +193,11 @@ def distill(
vocabulary: list[str] | None = None,
device: str | None = None,
pca_dims: PCADimType = 256,
apply_zipf: bool | None = None,
sif_coefficient: float | None = 1e-4,
token_remove_pattern: str | None = r"\[unused\d+\]",
trust_remote_code: bool = False,
quantize_to: DType | str = DType.Float16,
use_subword: bool | None = None,
lower_case: bool = True,
vocabulary_quantization: int | None = None,
) -> StaticModel:
"""
@@ -242,14 +215,13 @@
:param pca_dims: The number of components to use for PCA.
If this is None, we don't apply PCA.
If this is 'auto', we don't reduce dimensionality, but still apply PCA.
:param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
:param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
:param lower_case: If this is set, all tokens in the model vocabulary will be converted to lowercase, and
a lowercase normalizer will be inserted. This almost always improves performance.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:return: A StaticModel

@@ -266,10 +238,9 @@ def distill(
vocabulary=vocabulary,
device=device,
pca_dims=pca_dims,
apply_zipf=apply_zipf,
token_remove_pattern=token_remove_pattern,
sif_coefficient=sif_coefficient,
quantize_to=quantize_to,
use_subword=use_subword,
lower_case=lower_case,
vocabulary_quantization=vocabulary_quantization,
)
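
For context, a minimal usage sketch of the updated `distill` signature. This is not part of the diff: the model name and output path are placeholders, and the `model_name` keyword and `StaticModel.save_pretrained` are assumed from the library's public API rather than shown in these hunks.

```python
from model2vec.distill import distill
from model2vec.quantization import DType

# Hypothetical example: distill a static model from a transformer checkpoint.
# "BAAI/bge-base-en-v1.5" and the output path are placeholders.
static_model = distill(
    model_name="BAAI/bge-base-en-v1.5",
    pca_dims=256,                  # None skips PCA; "auto" keeps the full dimensionality
    sif_coefficient=1e-4,          # None disables SIF weighting; must be > 0 and < 1.0
    token_remove_pattern=r"\[unused\d+\]",
    quantize_to=DType.Float16,
    lower_case=True,               # lowercase the vocabulary and insert a lowercase normalizer
    vocabulary_quantization=None,  # or a cluster count to quantize the vocabulary
)
static_model.save_pretrained("output/my-static-model")
```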
3 changes: 2 additions & 1 deletion model2vec/distill/inference.py
@@ -98,7 +98,8 @@ def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.
"""
encodings = {k: v.to(model.device) for k, v in encodings.items()}
encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
out: torch.Tensor = encoded.last_hidden_state.cpu() # type: ignore # False positive
assert encoded.last_hidden_state is not None
out: torch.Tensor = encoded.last_hidden_state.cpu()
# NOTE: If the dtype is bfloat16, we convert to float32,
# because numpy does not support bfloat16
# See here: https://github.com/numpy/numpy/issues/19808
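
The note about numpy's missing bfloat16 support (numpy/numpy#19808) applies anywhere a hidden state leaves torch; a minimal standalone sketch of the same guard, not taken from the diff:

```python
import torch

def hidden_state_to_numpy(out: torch.Tensor):
    """Move a hidden-state tensor to CPU and convert it to a numpy array."""
    out = out.cpu()
    # numpy has no bfloat16 dtype, so upcast to float32 first.
    if out.dtype == torch.bfloat16:
        out = out.float()
    return out.numpy()
```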
3 changes: 1 addition & 2 deletions model2vec/tokenizer/__init__.py
@@ -4,9 +4,8 @@

from model2vec.tokenizer.tokenizer import (
clean_and_create_vocabulary,
create_tokenizer,
replace_vocabulary,
turn_tokens_into_ids,
)

__all__ = ["clean_and_create_vocabulary", "create_tokenizer", "turn_tokens_into_ids", "replace_vocabulary"]
__all__ = ["clean_and_create_vocabulary", "turn_tokens_into_ids", "replace_vocabulary"]
5 changes: 4 additions & 1 deletion model2vec/tokenizer/datamodels.py
@@ -5,10 +5,13 @@
class Token:
"""A class to represent a token."""

# The surface form: used for featurizing
form: str
# The normalized and pretokenized form of the token
# The normalized form: preprocessed by the new tokenizer
normalized_form: str
# Whether the word is a continuing subword.
is_subword: bool
# Whether the token is internal to the model.
is_internal: bool
# Whether the token is a multiword token
is_multiword: bool = False
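
A small, hypothetical instantiation of the fields documented above (assuming `Token` is constructed with keyword arguments; the values are illustrative only):

```python
from model2vec.tokenizer.datamodels import Token

# A regular word that is not part of the base model's vocabulary.
token = Token(
    form="hello",             # surface form, used for featurizing
    normalized_form="hello",  # form after the new tokenizer's normalization
    is_subword=False,         # not a continuing subword piece
    is_internal=False,        # not already internal to the model
)
```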
43 changes: 0 additions & 43 deletions model2vec/tokenizer/model.py

This file was deleted.

42 changes: 0 additions & 42 deletions model2vec/tokenizer/normalizer.py

This file was deleted.

57 changes: 0 additions & 57 deletions model2vec/tokenizer/pretokenizer.py

This file was deleted.
