Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.10", "3.11", "3.12", "3.13"]
fail-fast: false

steps:
Expand Down
6 changes: 3 additions & 3 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import re
from typing import Optional, cast
from typing import cast

import numpy as np
from huggingface_hub.hf_api import model_info
Expand Down Expand Up @@ -87,8 +87,8 @@ def distill_from_model(
if not all_tokens:
raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.")

unk_token = cast(Optional[str], tokenizer.special_tokens_map.get("unk_token"))
pad_token = cast(Optional[str], tokenizer.special_tokens_map.get("pad_token"))
unk_token = cast(str | None, tokenizer.special_tokens_map.get("unk_token"))
pad_token = cast(str | None, tokenizer.special_tokens_map.get("pad_token"))

# Weird if to satisfy mypy
if pad_token is None:
Expand Down
7 changes: 3 additions & 4 deletions model2vec/distill/inference.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import inspect
import logging
from enum import Enum
from pathlib import Path
from typing import Literal, Union
from typing import Literal

import numpy as np
import torch
Expand All @@ -17,8 +16,8 @@

logger = logging.getLogger(__name__)

PathLike = Union[Path, str]
PCADimType = Union[int, None, float, Literal["auto"]]
PathLike = Path | str
PCADimType = int | None | float | Literal["auto"]

_DEFAULT_BATCH_SIZE = 256

Expand Down
7 changes: 4 additions & 3 deletions model2vec/inference/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import re
from collections.abc import Sequence
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Sequence, TypeVar, Union, cast
from typing import TypeVar, cast

import huggingface_hub
import numpy as np
Expand Down Expand Up @@ -293,14 +294,14 @@ def evaluate_single_or_multi_label(
"""
if _is_multi_label_shaped(y):
# Cast because the type checker doesn't understand that y is a list of lists.
y = cast(Union[list[list[str]], list[list[int]]], y)
y = cast(list[list[str]] | list[list[int]], y)
classes = sorted(set([label for labels in y for label in labels]))
mlb = MultiLabelBinarizer(classes=classes)
y_transformed = mlb.fit_transform(y)
predictions_transformed = mlb.transform(predictions)
else:
if all(isinstance(label, (str, int)) for label in y):
y = cast(Union[list[str], list[int]], y)
y = cast(list[str] | list[int], y)
classes = sorted(set(y))
y_transformed = np.array(y)
predictions_transformed = np.array(predictions)
Expand Down
5 changes: 3 additions & 2 deletions model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import math
import os
from collections.abc import Iterator, Sequence
from logging import getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Iterator, Sequence, Union, overload
from typing import Any, overload

import numpy as np
from joblib import delayed
Expand All @@ -15,7 +16,7 @@
from model2vec.quantization import DType, quantize_and_reduce_dim
from model2vec.utils import ProgressParallel

PathLike = Union[Path, str]
PathLike = Path | str

logger = getLogger(__name__)

Expand Down
6 changes: 3 additions & 3 deletions model2vec/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import re
from typing import Any, Optional, cast
from typing import Any, cast

from tokenizers import Tokenizer
from tokenizers.normalizers import Normalizer
Expand Down Expand Up @@ -387,8 +387,8 @@ def create_tokenizer(
:param token_remove_regex: The regex to use to remove tokens from the vocabulary.
:return: The created tokenizer.
"""
unk_token = cast(Optional[str], tokenizer.special_tokens_map.get("unk_token"))
pad_token = cast(Optional[str], tokenizer.special_tokens_map.get("pad_token"))
unk_token = cast(str | None, tokenizer.special_tokens_map.get("unk_token"))
pad_token = cast(str | None, tokenizer.special_tokens_map.get("pad_token"))
cleaned_vocabulary, backend_tokenizer = clean_and_create_vocabulary(tokenizer, vocabulary, token_remove_regex)
new_tokenizer = replace_vocabulary(backend_tokenizer, cleaned_vocabulary, unk_token, pad_token)

Expand Down
4 changes: 2 additions & 2 deletions model2vec/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import logging
import re
from collections.abc import Iterator
from importlib import import_module
from importlib.metadata import metadata
from typing import Any, Iterator, Protocol
from typing import Any, Protocol

import numpy as np
from joblib import Parallel
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "model2vec"
description = "Fast State-of-the-Art Static Embeddings"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
requires-python = ">=3.9"
requires-python = ">=3.10"
authors = [{ name = "Stéphan Tulkens", email = "[email protected]"}, {name = "Thomas van Dongen", email = "[email protected]"}]
dynamic = ["version"]

Expand All @@ -15,7 +15,6 @@ classifiers = [
"Topic :: Software Development :: Libraries",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand Down
Loading