14
14
import json
15
15
import os
16
16
from pathlib import Path
17
+ from typing import cast
17
18
18
19
import torch
19
20
from huggingface_hub import snapshot_download
@@ -248,10 +249,7 @@ def load_model_config(
248
249
return model .config # type: ignore[attr-defined]
249
250
250
251
if not isinstance (model , (str , os .PathLike )):
251
- raise TypeError (
252
- "Expected model to be a string, Path, or PreTrainedModel, "
253
- f"got { type (model )} "
254
- )
252
+ raise TypeError (f"Expected model to be a string or Path, got { type (model )} " )
255
253
256
254
try :
257
255
logger .debug (f"Loading config with AutoConfig from: { model } " )
@@ -271,6 +269,12 @@ def load_model_config(
271
269
272
270
def load_model_checkpoint_config_dict (
273
271
config : str | os .PathLike | PretrainedConfig | PreTrainedModel | dict ,
272
+ cache_dir : str | Path | None = None ,
273
+ force_download : bool = False ,
274
+ local_files_only : bool = False ,
275
+ token : str | bool | None = None ,
276
+ revision : str | None = None ,
277
+ ** kwargs ,
274
278
) -> dict :
275
279
"""
276
280
Load model configuration as dictionary from various sources.
@@ -279,6 +283,12 @@ def load_model_checkpoint_config_dict(
279
283
or extracting from existing model/config instances.
280
284
281
285
:param config: Local path, PretrainedConfig, PreTrainedModel, or dict
286
+ :param cache_dir: Directory to cache downloaded files
287
+ :param force_download: Whether to force re-download existing files
288
+ :param local_files_only: Only use cached files without downloading
289
+ :param token: Authentication token for private models
290
+ :param revision: Model revision (branch, tag, or commit hash)
291
+ :param kwargs: Additional arguments for `check_download_model_config`
282
292
:return: Configuration dictionary
283
293
:raises TypeError: If config is not a supported type
284
294
:raises FileNotFoundError: If config.json cannot be found
@@ -301,7 +311,18 @@ def load_model_checkpoint_config_dict(
301
311
f"or PretrainedConfig, got { type (config )} "
302
312
)
303
313
304
- path = Path (config )
314
+ path = cast (
315
+ "Path" ,
316
+ check_download_model_config (
317
+ config ,
318
+ cache_dir = cache_dir ,
319
+ force_download = force_download ,
320
+ local_files_only = local_files_only ,
321
+ token = token ,
322
+ revision = revision ,
323
+ ** kwargs ,
324
+ ),
325
+ )
305
326
306
327
if path .is_dir ():
307
328
path = path / "config.json"
@@ -378,7 +399,7 @@ def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
378
399
Searches for weight files in various formats (.bin, .safetensors) through
379
400
automatic detection of different organization patterns.
380
401
381
- :param path: Local checkpoint directory, index file, or weight file path
402
+ :param path: HF ID, local checkpoint directory, index file, or weight file path
382
403
:return: List of paths to weight files
383
404
:raises TypeError: If path is not a string or Path-like object
384
405
:raises FileNotFoundError: If path doesn't exist or no weight files found
@@ -416,14 +437,26 @@ def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
416
437
417
438
def load_model_checkpoint_state_dict (
418
439
model : str | os .PathLike | PreTrainedModel | nn .Module ,
440
+ cache_dir : str | Path | None = None ,
441
+ force_download : bool = False ,
442
+ local_files_only : bool = False ,
443
+ token : str | bool | None = None ,
444
+ revision : str | None = None ,
445
+ ** kwargs ,
419
446
) -> dict [str , Tensor ]:
420
447
"""
421
448
Load complete model state dictionary from various sources.
422
449
423
450
Supports loading from model instances, local checkpoint directories,
424
451
or individual weight files with automatic format detection.
425
452
426
- :param model: Model instance, checkpoint directory, or weight file path
453
+ :param model: Model instance, HF ID, checkpoint directory, or weight file path
454
+ :param cache_dir: Directory to cache downloaded files
455
+ :param force_download: Whether to force re-download existing files
456
+ :param local_files_only: Only use cached files without downloading
457
+ :param token: Authentication token for private models
458
+ :param revision: Model revision (branch, tag, or commit hash)
459
+ :param kwargs: Additional arguments for `check_download_model_checkpoint`
427
460
:return: Dictionary mapping parameter names to tensors
428
461
:raises ValueError: If unsupported file format is encountered
429
462
"""
@@ -432,7 +465,17 @@ def load_model_checkpoint_state_dict(
432
465
return model .state_dict () # type: ignore[union-attr]
433
466
434
467
logger .debug (f"Loading model weights from: { model } " )
435
- weight_files = load_model_checkpoint_weight_files (model )
468
+ weight_files = load_model_checkpoint_weight_files (
469
+ check_download_model_checkpoint (
470
+ model ,
471
+ cache_dir = cache_dir ,
472
+ force_download = force_download ,
473
+ local_files_only = local_files_only ,
474
+ token = token ,
475
+ revision = revision ,
476
+ ** kwargs ,
477
+ )
478
+ )
436
479
437
480
state_dict = {}
438
481
0 commit comments