
Commit dea8e9a

Refactor and initial test fixes
1 parent 0e670b3 commit dea8e9a


2 files changed: +52, -11 lines

src/speculators/utils/transformers_utils.py

Lines changed: 51 additions & 8 deletions
@@ -14,6 +14,7 @@
 import json
 import os
 from pathlib import Path
+from typing import cast
 
 import torch
 from huggingface_hub import snapshot_download
@@ -248,10 +249,7 @@ def load_model_config(
         return model.config  # type: ignore[attr-defined]
 
     if not isinstance(model, (str, os.PathLike)):
-        raise TypeError(
-            "Expected model to be a string, Path, or PreTrainedModel, "
-            f"got {type(model)}"
-        )
+        raise TypeError(f"Expected model to be a string or Path, got {type(model)}")
 
     try:
         logger.debug(f"Loading config with AutoConfig from: {model}")
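
Note: the surrounding context suggests PreTrainedModel inputs return early via model.config just above, so the narrowed message only affects genuinely unsupported inputs. A minimal caller-side illustration (not part of the commit):

from speculators.utils.transformers_utils import load_model_config

try:
    load_model_config(123)  # type: ignore[arg-type]
except TypeError as err:
    # Matches the message introduced in this hunk.
    assert "Expected model to be a string or Path" in str(err)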
@@ -271,6 +269,12 @@ def load_model_checkpoint_config_dict(
 
 def load_model_checkpoint_config_dict(
     config: str | os.PathLike | PretrainedConfig | PreTrainedModel | dict,
+    cache_dir: str | Path | None = None,
+    force_download: bool = False,
+    local_files_only: bool = False,
+    token: str | bool | None = None,
+    revision: str | None = None,
+    **kwargs,
 ) -> dict:
     """
     Load model configuration as dictionary from various sources.
@@ -279,6 +283,12 @@ def load_model_checkpoint_config_dict(
     or extracting from existing model/config instances.
 
     :param config: Local path, PretrainedConfig, PreTrainedModel, or dict
+    :param cache_dir: Directory to cache downloaded files
+    :param force_download: Whether to force re-download existing files
+    :param local_files_only: Only use cached files without downloading
+    :param token: Authentication token for private models
+    :param revision: Model revision (branch, tag, or commit hash)
+    :param kwargs: Additional arguments for `check_download_model_config`
     :return: Configuration dictionary
     :raises TypeError: If config is not a supported type
     :raises FileNotFoundError: If config.json cannot be found
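
With the extended signature, a config dict can be pulled by Hub ID as well as from a local path. A rough usage sketch (the repo id, cache path, and revision below are placeholders, not values from this commit):

from speculators.utils.transformers_utils import load_model_checkpoint_config_dict

config_dict = load_model_checkpoint_config_dict(
    "org/model",               # placeholder Hugging Face repo id
    cache_dir="/tmp/hf-cache",
    revision="main",
    local_files_only=False,
)
print(config_dict.get("model_type"))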
@@ -301,7 +311,18 @@ def load_model_checkpoint_config_dict(
             f"or PretrainedConfig, got {type(config)}"
         )
 
-    path = Path(config)
+    path = cast(
+        "Path",
+        check_download_model_config(
+            config,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            **kwargs,
+        ),
+    )
 
     if path.is_dir():
         path = path / "config.json"
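
check_download_model_config itself is not part of this diff. Purely as a mental model (an assumption, not the repository's implementation), a helper of that shape could pass existing local paths through unchanged and resolve Hub IDs with snapshot_download, which this module already imports:

from pathlib import Path

from huggingface_hub import snapshot_download


def resolve_config_source(  # hypothetical helper, for illustration only
    source: str,
    cache_dir: str | Path | None = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: str | bool | None = None,
    revision: str | None = None,
) -> Path:
    # Local paths are returned unchanged.
    if Path(source).exists():
        return Path(source)
    # Otherwise treat `source` as a Hub repo id and fetch just the config file.
    downloaded = snapshot_download(
        repo_id=source,
        allow_patterns=["config.json"],
        cache_dir=cache_dir,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
    )
    return Path(downloaded)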
@@ -378,7 +399,7 @@ def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
     Searches for weight files in various formats (.bin, .safetensors) through
     automatic detection of different organization patterns.
 
-    :param path: Local checkpoint directory, index file, or weight file path
+    :param path: HF ID, local checkpoint directory, index file, or weight file path
     :return: List of paths to weight files
     :raises TypeError: If path is not a string or Path-like object
     :raises FileNotFoundError: If path doesn't exist or no weight files found
@@ -416,14 +437,26 @@ def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
 
 def load_model_checkpoint_state_dict(
     model: str | os.PathLike | PreTrainedModel | nn.Module,
+    cache_dir: str | Path | None = None,
+    force_download: bool = False,
+    local_files_only: bool = False,
+    token: str | bool | None = None,
+    revision: str | None = None,
+    **kwargs,
 ) -> dict[str, Tensor]:
     """
     Load complete model state dictionary from various sources.
 
     Supports loading from model instances, local checkpoint directories,
     or individual weight files with automatic format detection.
 
-    :param model: Model instance, checkpoint directory, or weight file path
+    :param model: Model instance, HF ID, checkpoint directory, or weight file path
+    :param cache_dir: Directory to cache downloaded files
+    :param force_download: Whether to force re-download existing files
+    :param local_files_only: Only use cached files without downloading
+    :param token: Authentication token for private models
+    :param revision: Model revision (branch, tag, or commit hash)
+    :param kwargs: Additional arguments for `check_download_model_checkpoint`
     :return: Dictionary mapping parameter names to tensors
     :raises ValueError: If unsupported file format is encountered
     """
@@ -432,7 +465,17 @@ def load_model_checkpoint_state_dict(
         return model.state_dict()  # type: ignore[union-attr]
 
     logger.debug(f"Loading model weights from: {model}")
-    weight_files = load_model_checkpoint_weight_files(model)
+    weight_files = load_model_checkpoint_weight_files(
+        check_download_model_checkpoint(
+            model,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            **kwargs,
+        )
+    )
 
     state_dict = {}
 
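The loop that fills state_dict sits outside this hunk. Purely as an assumption based on the .bin/.safetensors formats the docstring mentions (not the file's actual code), it plausibly looks something like:

from pathlib import Path

import torch
from safetensors.torch import load_file


def merge_weight_files(weight_files: list[Path]) -> dict[str, torch.Tensor]:
    # Hypothetical sketch: merge sharded checkpoint files into one state dict.
    state_dict: dict[str, torch.Tensor] = {}
    for weight_file in weight_files:
        if weight_file.suffix == ".safetensors":
            shard = load_file(weight_file)
        elif weight_file.suffix == ".bin":
            shard = torch.load(weight_file, map_location="cpu")
        else:
            raise ValueError(f"Unsupported weight file format: {weight_file.suffix}")
        state_dict.update(shard)
    return state_dict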
tests/unit/utils/test_transformers_utils.py

Lines changed: 1 addition & 3 deletions
@@ -472,9 +472,7 @@ def test_invalid_invocation(self, mock_auto_config):
         """Test with invalid input types and missing configs."""
         with pytest.raises(TypeError) as type_exc:
             load_model_config(123)  # type: ignore[arg-type]
-        assert "Expected model to be a string, Path, or PreTrainedModel" in str(
-            type_exc.value
-        )
+        assert "Expected model to be a string or Path" in str(type_exc.value)
 
         mock_auto_config.from_pretrained.side_effect = ValueError("Config not found")
         with pytest.raises(FileNotFoundError) as file_exc:
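
The existing test only covers the message change; a follow-up test along these lines (a sketch, not part of this commit) could check that the new download kwargs are forwarded:

from unittest import mock

from speculators.utils import transformers_utils


def test_state_dict_forwards_download_kwargs(tmp_path):
    # Sketch: the new kwargs should reach check_download_model_checkpoint.
    with mock.patch.object(
        transformers_utils, "check_download_model_checkpoint", return_value=tmp_path
    ) as check_mock, mock.patch.object(
        transformers_utils, "load_model_checkpoint_weight_files", return_value=[]
    ):
        transformers_utils.load_model_checkpoint_state_dict(
            "org/model", revision="main", local_files_only=True
        )
    check_mock.assert_called_once()
    assert check_mock.call_args.kwargs["revision"] == "main"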
