Skip to content

Commit 23fe6c8

Browse files
committed
Add unit tests
1 parent 3daba9c commit 23fe6c8

File tree

5 files changed

+706
-9
lines changed

5 files changed

+706
-9
lines changed

src/speculators/convert/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- HASS: https://github.com/HArmonizedSS/HASS
1010
"""
1111

12+
from .converters import SpeculatorConverter
1213
from .entrypoints import convert_model
1314

14-
__all__ = ["convert_model"]
15+
__all__ = ["SpeculatorConverter", "convert_model"]

src/speculators/convert/converters/base.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
"""Generic type variable for speculator models"""
3131

3232

33-
class SpeculatorConverter(ABC, Generic[ConfigT, ModelT], RegistryMixin):
33+
class SpeculatorConverter(ABC, RegistryMixin, Generic[ConfigT, ModelT]):
3434
"""
3535
Abstract base converter for transforming external checkpoints to Speculators format.
3636
@@ -55,8 +55,8 @@ class SpeculatorConverter(ABC, Generic[ConfigT, ModelT], RegistryMixin):
5555
def resolve_converter(
5656
cls,
5757
algorithm: str,
58-
model: Path | PreTrainedModel | nn.Module,
59-
config: Path | PretrainedConfig | dict,
58+
model: str | Path | PreTrainedModel | nn.Module,
59+
config: str | Path | PretrainedConfig | dict,
6060
verifier: str | os.PathLike | PreTrainedModel | None = None,
6161
**kwargs,
6262
) -> type[SpeculatorConverter]:
@@ -104,8 +104,8 @@ def resolve_converter(
104104
@abstractmethod
105105
def is_supported(
106106
cls,
107-
model: Path | PreTrainedModel | nn.Module,
108-
config: Path | PretrainedConfig | dict,
107+
model: str | Path | PreTrainedModel | nn.Module,
108+
config: str | Path | PretrainedConfig | dict,
109109
verifier: str | os.PathLike | PreTrainedModel | None = None,
110110
**kwargs,
111111
) -> bool:
@@ -122,8 +122,8 @@ def is_supported(
122122

123123
def __init__(
124124
self,
125-
model: Path | PreTrainedModel | nn.Module,
126-
config: Path | PretrainedConfig | dict,
125+
model: str | Path | PreTrainedModel | nn.Module,
126+
config: str | Path | PretrainedConfig | dict,
127127
verifier: str | os.PathLike | PreTrainedModel | None,
128128
):
129129
"""
@@ -135,7 +135,7 @@ def __init__(
135135
:raises ValueError: If model or config is None or empty
136136
"""
137137

138-
if model is None or config is None:
138+
if model is None or config is None or model == "" or config == "":
139139
raise ValueError(
140140
f"Model and config paths must be provided, got {model}, {config}"
141141
)

tests/unit/convert/converters/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)