
[Core] Support model loader plugins #21067


Merged: 11 commits merged on Jul 24, 2025
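
In short, this PR replaces the hard-coded LoadFormat enum dispatch with a string-keyed loader registry, so out-of-tree loaders can be plugged in through a new register_model_loader decorator. The sketch below shows the intended end-to-end flow using only the API added in this diff; the loader body and the "my_format" name are placeholders for illustration, not code from the PR:

from torch import nn

from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader import (get_model_loader,
                                              register_model_loader)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader


@register_model_loader("my_format")  # any name not already in the registry
class MyModelLoader(BaseModelLoader):

    def download_model(self, model_config: ModelConfig) -> None:
        # Fetch weights from a custom source; left as a placeholder here.
        pass

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        # Populate `model` with weights from that source; placeholder.
        pass


# The engine resolves the loader class from the plain string at runtime.
loader = get_model_loader(LoadConfig(load_format="my_format"))
assert isinstance(loader, MyModelLoader)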
4 changes: 1 addition & 3 deletions tests/fastsafetensors_loader/test_fastsafetensors_loader.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import SamplingParams
from vllm.config import LoadFormat

test_model = "openai-community/gpt2"

@@ -17,7 +16,6 @@


def test_model_loader_download_files(vllm_runner):
with vllm_runner(test_model,
load_format=LoadFormat.FASTSAFETENSORS) as llm:
with vllm_runner(test_model, load_format="fastsafetensors") as llm:
deserialized_outputs = llm.generate(prompts, sampling_params)
assert deserialized_outputs
Empty file.
37 changes: 37 additions & 0 deletions tests/model_executor/model_loader/test_registry.py
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from torch import nn

from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader import (get_model_loader,
                                              register_model_loader)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader


@register_model_loader("custom_load_format")
class CustomModelLoader(BaseModelLoader):

    def __init__(self, load_config: LoadConfig) -> None:
        super().__init__(load_config)

    def download_model(self, model_config: ModelConfig) -> None:
        pass

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        pass


def test_register_model_loader():
    load_config = LoadConfig(load_format="custom_load_format")
    assert isinstance(get_model_loader(load_config), CustomModelLoader)


def test_invalid_model_loader():
    with pytest.raises(ValueError):

        @register_model_loader("invalid_load_format")
        class InValidModelLoader:
            pass
@@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import SamplingParams
from vllm.config import LoadConfig, LoadFormat
from vllm.config import LoadConfig
from vllm.model_executor.model_loader import get_model_loader

load_format = "runai_streamer"
test_model = "openai-community/gpt2"

prompts = [
@@ -18,7 +19,7 @@


def get_runai_model_loader():
load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
load_config = LoadConfig(load_format=load_format)
return get_model_loader(load_config)


@@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():


def test_runai_model_loader_download_files(vllm_runner):
with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
with vllm_runner(test_model, load_format=load_format) as llm:
deserialized_outputs = llm.generate(prompts, sampling_params)
assert deserialized_outputs
30 changes: 6 additions & 24 deletions vllm/config.py
@@ -65,7 +65,7 @@
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader import BaseModelLoader
from vllm.model_executor.model_loader import LoadFormats
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

ConfigType = type[DataclassInstance]
@@ -78,6 +78,7 @@
QuantizationConfig = Any
QuantizationMethods = Any
BaseModelLoader = Any
LoadFormats = Any
TensorizerConfig = Any
ConfigType = type
HfOverrides = Union[dict[str, Any], Callable[[type], type]]
@@ -1773,29 +1774,12 @@ def verify_with_parallel_config(
logger.warning("Possibly too large swap space. %s", msg)


class LoadFormat(str, enum.Enum):
AUTO = "auto"
PT = "pt"
SAFETENSORS = "safetensors"
NPCACHE = "npcache"
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"
RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
FASTSAFETENSORS = "fastsafetensors"


@config
@dataclass
class LoadConfig:
"""Configuration for loading the model weights."""

load_format: Union[str, LoadFormat,
"BaseModelLoader"] = LoadFormat.AUTO.value
load_format: Union[str, LoadFormats] = "auto"
"""The format of the model weights to load:\n
- "auto" will try to load the weights in the safetensors format and fall
back to the pytorch bin format if safetensors format is not available.\n
@@ -1816,7 +1800,8 @@ class LoadConfig:
- "gguf" will load weights from GGUF format files (details specified in
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
- "mistral" will load weights from consolidated safetensors files used by
Mistral models."""
Mistral models.
- Other custom values can be supported via plugins."""
download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
@@ -1864,10 +1849,7 @@ def compute_hash(self) -> str:
return hash_str

def __post_init__(self):
if isinstance(self.load_format, str):
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)

self.load_format = self.load_format.lower()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info(
"Ignoring the following patterns when downloading weights: %s",
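A consequence of dropping the enum, visible in the __post_init__ change above: LoadConfig no longer rejects unknown formats up front. It only lower-cases the string, and an unrecognized value fails later, when get_model_loader performs the registry lookup. A small illustration of that behaviour as implied by this diff (assuming no extra validation elsewhere):

from vllm.config import LoadConfig
from vllm.model_executor.model_loader import get_model_loader

cfg = LoadConfig(load_format="Sharded_State")   # mixed case is accepted
assert cfg.load_format == "sharded_state"       # __post_init__ lower-cases it

cfg = LoadConfig(load_format="not_a_real_format")  # passes config construction
try:
    get_model_loader(cfg)                          # fails only at lookup time
except ValueError as err:
    print(err)  # Load format `not_a_real_format` is not supported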
28 changes: 13 additions & 15 deletions vllm/engine/arg_utils.py
@@ -26,13 +26,12 @@
DetailedTraceModules, Device, DeviceConfig,
DistributedExecutorBackend, GuidedDecodingBackend,
GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LoadFormat,
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
get_field)
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, ModelConfig, ModelDType, ModelImpl,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PoolerConfig, PrefixCachingHashAlgo, SchedulerConfig,
SchedulerPolicy, SpeculativeConfig, TaskOption,
TokenizerMode, VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -47,10 +46,12 @@
if TYPE_CHECKING:
from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.model_loader import LoadFormats
from vllm.usage.usage_lib import UsageContext
else:
ExecutorBase = Any
QuantizationMethods = Any
LoadFormats = Any
UsageContext = Any

logger = init_logger(__name__)
@@ -276,7 +277,7 @@ class EngineArgs:
trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
download_dir: Optional[str] = LoadConfig.download_dir
load_format: str = LoadConfig.load_format
load_format: Union[str, LoadFormats] = LoadConfig.load_format
config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
@@ -545,9 +546,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
title="LoadConfig",
description=LoadConfig.__doc__,
)
load_group.add_argument("--load-format",
choices=[f.value for f in LoadFormat],
**load_kwargs["load_format"])
load_group.add_argument("--load-format", **load_kwargs["load_format"])
load_group.add_argument("--download-dir",
**load_kwargs["download_dir"])
load_group.add_argument("--model-loader-extra-config",
@@ -854,10 +853,9 @@ def create_model_config(self) -> ModelConfig:

# NOTE: This is to allow model loading from S3 in CI
if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
and self.model in MODELS_ON_S3
and self.load_format == LoadFormat.AUTO): # noqa: E501
and self.model in MODELS_ON_S3 and self.load_format == "auto"):
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
self.load_format = LoadFormat.RUNAI_STREAMER
self.load_format = "runai_streamer"

return ModelConfig(
model=self.model,
@@ -1261,7 +1259,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
#############################################################
# Unsupported Feature Flags on V1.

if self.load_format == LoadFormat.SHARDED_STATE.value:
if self.load_format == "sharded_state":
_raise_or_fallback(
feature_name=f"--load_format {self.load_format}",
recommend_to_remove=False)
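With the choices=[f.value for f in LoadFormat] restriction removed above, a plugin-registered format passes straight through the engine arguments (and the --load-format CLI flag). A short sketch, assuming a custom loader registered under the hypothetical name "my_format" as in the earlier example:

from vllm.engine.arg_utils import EngineArgs

# Any string is accepted here now; it is only checked against the loader
# registry when the model loader is actually resolved.
args = EngineArgs(model="openai-community/gpt2", load_format="my_format")
assert args.load_format == "my_format"

# Roughly equivalent on the command line (assuming the plugin is installed):
#   vllm serve openai-community/gpt2 --load-format my_format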
114 changes: 87 additions & 27 deletions vllm/model_executor/model_loader/__init__.py
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional
from typing import Literal, Optional

from torch import nn

from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import (
BitsAndBytesModelLoader)
@@ -20,34 +21,92 @@
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture, get_model_cls)

logger = init_logger(__name__)

# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
LoadFormats = Literal[
"auto",
"bitsandbytes",
"dummy",
"fastsafetensors",
"gguf",
"mistral",
"npcache",
"pt",
"runai_streamer",
"runai_streamer_sharded",
"safetensors",
"sharded_state",
"tensorizer",
]
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
"auto": DefaultModelLoader,
"bitsandbytes": BitsAndBytesModelLoader,
"dummy": DummyModelLoader,
"fastsafetensors": DefaultModelLoader,
"gguf": GGUFModelLoader,
"mistral": DefaultModelLoader,
"npcache": DefaultModelLoader,
"pt": DefaultModelLoader,
"runai_streamer": RunaiModelStreamerLoader,
"runai_streamer_sharded": ShardedStateLoader,
"safetensors": DefaultModelLoader,
"sharded_state": ShardedStateLoader,
"tensorizer": TensorizerLoader,
}
Comment on lines +28 to +57

Contributor:

this could've remained an enum and would've supported a to_model_loader method.

Member:

Personally, I do prefer using a Literal as it makes for nicer type hinting.

The way that @22quinn has organised the typing and the registry is exactly the same as for quantization methods. If we change one we should probably change both? Maybe as a follow-up task to improve the way we handle built-in plugins in general?

Collaborator (Author):

Thanks for the review! I don't have a strong opinion on this, but I agree we'd better be consistent everywhere. I'm leaving it as Literal for now.

Member:

Yeah, let's keep them consistent for now.
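
For context, the alternative the first reviewer describes would look roughly like the sketch below: keep LoadFormat as an enum and hang the loader dispatch off a to_model_loader method. This is not what the PR does; it is only included to make the trade-off concrete, since a plugin would still have to modify the enum or its mapping, which is exactly what the string registry avoids:

import enum

from vllm.config import LoadConfig
from vllm.model_executor.model_loader import (DefaultModelLoader,
                                              DummyModelLoader)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader


class LoadFormat(str, enum.Enum):
    AUTO = "auto"
    DUMMY = "dummy"
    # ... remaining built-in formats ...

    def to_model_loader(self, load_config: LoadConfig) -> BaseModelLoader:
        # Dispatch from the enum member to a loader class; extending this
        # from outside vLLM would require editing the enum itself.
        mapping: dict["LoadFormat", type[BaseModelLoader]] = {
            LoadFormat.AUTO: DefaultModelLoader,
            LoadFormat.DUMMY: DummyModelLoader,
        }
        return mapping[self](load_config)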



def register_model_loader(load_format: str):
    """Register a customized vllm model loader.

    When a load format is not supported by vllm, you can register a customized
    model loader to support it.

    Args:
        load_format (str): The model loader format name.

    Examples:
        >>> from vllm.config import LoadConfig
        >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
        >>>
        >>> @register_model_loader("my_loader")
        ... class MyModelLoader(BaseModelLoader):
        ...     def download_model(self):
        ...         pass
        ...
        ...     def load_weights(self):
        ...         pass
        >>>
        >>> load_config = LoadConfig(load_format="my_loader")
        >>> type(get_model_loader(load_config))
        <class 'MyModelLoader'>
    """ # noqa: E501

    def _wrapper(model_loader_cls):
        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
            logger.warning(
                "Load format `%s` is already registered, and will be "
                "overwritten by the new loader class `%s`.", load_format,
                model_loader_cls)
        if not issubclass(model_loader_cls, BaseModelLoader):
            raise ValueError("The model loader must be a subclass of "
                             "`BaseModelLoader`.")
        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
        logger.info("Registered model loader `%s` with load format `%s`",
                    model_loader_cls, load_format)
        return model_loader_cls

    return _wrapper


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)

if load_config.load_format == LoadFormat.DUMMY:
return DummyModelLoader(load_config)

if load_config.load_format == LoadFormat.TENSORIZER:
return TensorizerLoader(load_config)

if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)

if load_config.load_format == LoadFormat.BITSANDBYTES:
return BitsAndBytesModelLoader(load_config)

if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)

if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config)

if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
return ShardedStateLoader(load_config, runai_model_streamer=True)

return DefaultModelLoader(load_config)
load_format = load_config.load_format
if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
raise ValueError(f"Load format `{load_format}` is not supported")
return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)


def get_model(*,
@@ -66,6 +125,7 @@ def get_model(*,
"get_architecture_class_name",
"get_model_architecture",
"get_model_cls",
"register_model_loader",
"BaseModelLoader",
"BitsAndBytesModelLoader",
"GGUFModelLoader",
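
Finally, since arg_utils.py already calls load_general_plugins() before building configs (see its imports above), a third-party package can get its loader registered without any changes to vLLM itself. A hedged sketch of how such a package might be wired up; the package name, module layout, and the entry-point group are assumptions based on vLLM's general plugin mechanism, not something this PR defines:

# Hypothetical out-of-tree package "my_loader_plugin".
#
# pyproject.toml of the plugin (entry-point group assumed from vLLM's
# general plugin system, which load_general_plugins() discovers):
#
#   [project.entry-points."vllm.general_plugins"]
#   my_loader_plugin = "my_loader_plugin:register"

# my_loader_plugin/__init__.py
def register() -> None:
    # Importing the submodule runs its @register_model_loader("my_format")
    # decorator, adding the custom format to vLLM's loader registry.
    from my_loader_plugin import loader  # noqa: F401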