model2vec/model.py (5 additions, 0 deletions)

@@ -566,6 +566,11 @@ def _loading_helper(
         language=metadata.get("language"),
     )

+    # If no quantization or dimensionality reduction is requested,
+    # return the model as is.
+    if not any([vocabulary_quantization, quantize_to, dimensionality]):
+        return model
+
     return quantize_model(
         model=model,
         vocabulary_quantization=vocabulary_quantization,
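For context, a usage sketch of the fast path this guard enables (the model name is illustrative; quantize_to is the keyword argument exercised in the tests further down):

    from model2vec import StaticModel

    # No quantization or dimensionality reduction requested:
    # the loading helper now returns the loaded model as-is.
    model = StaticModel.from_pretrained("minishlab/potion-base-8M")

    # Passing any of the three options still routes through quantize_model.
    small = StaticModel.from_pretrained("minishlab/potion-base-8M", quantize_to="int8")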
model2vec/quantization.py (26 additions, 7 deletions)

@@ -15,6 +15,14 @@ class DType(str, Enum):
     Int8 = "int8"


+dtype_map = {
+    DType.Float16: np.float16,
+    DType.Float32: np.float32,
+    DType.Float64: np.float64,
+    DType.Int8: np.int8,
+}
+
+
 def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
     """
     Quantize embeddings to a specified data type to reduce memory usage.
@@ -24,17 +32,28 @@ def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
     :return: The quantized embeddings.
     :raises ValueError: If the quantization type is not valid.
     """
-    if quantize_to == DType.Float16:
-        return embeddings.astype(np.float16)
-    elif quantize_to == DType.Float32:
-        return embeddings.astype(np.float32)
-    elif quantize_to == DType.Float64:
-        return embeddings.astype(np.float64)
+    mapped_dtype = dtype_map[quantize_to]
+    if embeddings.dtype == mapped_dtype:
+        # Don't do anything if the dtypes already match.
+        return embeddings
+
+    # Handle float types.
+    if quantize_to in {DType.Float16, DType.Float32, DType.Float64}:
+        return embeddings.astype(mapped_dtype)
     elif quantize_to == DType.Int8:
         # Normalize to [-128, 127] range for int8
+        # We normalize to -127 to 127 to keep symmetry.
         scale = np.max(np.abs(embeddings)) / 127.0
-        quantized = np.round(embeddings / scale).astype(np.int8)
+        # Cast to float16 to minimize memory usage during computation;
+        # this is the only copy we make.
+        buf = embeddings.astype(np.float16, copy=True)
+        # Divide by the scale in place.
+        np.divide(buf, scale, out=buf)
+        # Round to the nearest integer in place.
+        np.rint(buf, out=buf)
+        # Clip to the symmetric int8 range and convert to int8.
+        np.clip(buf, -127, 127, out=buf)
+        quantized = buf.astype(np.int8)
         return quantized
     else:
         raise ValueError("Not a valid enum member of DType.")
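To make the int8 path concrete, here is a small worked example of the symmetric scheme above (values are illustrative; the steps mirror the buffered implementation in the diff):

    import numpy as np

    emb = np.array([[-2.0, 0.5, 1.0]], dtype=np.float32)
    scale = np.max(np.abs(emb)) / 127.0      # 2.0 / 127 ≈ 0.01575
    buf = emb.astype(np.float16, copy=True)  # the single copy
    np.divide(buf, scale, out=buf)           # [-127.0, 31.75, 63.5]
    np.rint(buf, out=buf)                    # [-127.0, 32.0, 64.0]
    np.clip(buf, -127, 127, out=buf)
    q = buf.astype(np.int8)                  # [[-127, 32, 64]]

    # Dequantizing recovers the input to within about scale/2 per element.
    print(q.astype(np.float32) * scale)      # ≈ [[-2.0, 0.504, 1.008]]

Dividing by max|x| / 127 maps the largest-magnitude value to ±127, and the final clip guards against any float16 rounding drifting past that bound.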
tests/test_model.py (7 additions, 0 deletions)

@@ -212,6 +212,13 @@ def test_load_pretrained_quantized(
     assert loaded_model.embedding.dtype == np.float32
     assert loaded_model.embedding.shape == mock_vectors.shape

+    # Load the model back from the same path, this time quantizing to float64.
+    loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float64")
+    # Assert that the embeddings were converted to the requested dtype.
+    assert loaded_model.embedding.dtype == np.float64
+    # Quantizing to the dtype the embeddings already have should be a no-op
+    # and return the same array, not a copy.
+    assert quantize_embeddings(loaded_model.embedding, DType.Float64) is loaded_model.embedding


 def test_load_pretrained_dim(
     tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
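The no-copy behavior the test targets comes from the early dtype check in quantize_embeddings (the rewritten assertion above assumes DType and quantize_embeddings are imported from model2vec.quantization in the test module). It can also be exercised directly, as a minimal sketch:

    import numpy as np
    from model2vec.quantization import DType, quantize_embeddings

    emb = np.random.rand(10, 4).astype(np.float64)
    out = quantize_embeddings(emb, DType.Float64)
    assert out is emb  # dtypes already match, so the input array is returned untouched

    half = quantize_embeddings(emb, DType.Float16)
    assert half is not emb and half.dtype == np.float16  # a real conversion allocates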