
Commit 9b9a10d

[Frontend] Dynamic RoPE scaling (vllm-project#4638)
1 parent 99eff67 commit 9b9a10d

File tree

5 files changed  +89 -12 lines changed

tests/test_config.py
vllm/config.py
vllm/engine/arg_utils.py
vllm/engine/llm_engine.py
vllm/transformers_utils/config.py


tests/test_config.py
+55 -1

@@ -36,4 +36,58 @@ def test_get_sliding_window():
     assert mistral_model_config.get_sliding_window() is None
 
     mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
+    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
+
+
+def test_rope_scaling():
+    TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
+    LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
+
+    llama_model_config = ModelConfig(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
+    assert llama_model_config.max_model_len == 8192
+
+    llama_model_config = ModelConfig(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+        rope_scaling=TEST_ROPE_SCALING,
+    )
+    assert getattr(llama_model_config.hf_config, "rope_scaling",
+                   None) == TEST_ROPE_SCALING
+    assert llama_model_config.max_model_len == 16384
+
+    longchat_model_config = ModelConfig(
+        "lmsys/longchat-13b-16k",
+        "lmsys/longchat-13b-16k",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+    assert getattr(longchat_model_config.hf_config, "rope_scaling",
+                   None) == LONGCHAT_ROPE_SCALING
+    assert longchat_model_config.max_model_len == 16384
+
+    longchat_model_config = ModelConfig(
+        "lmsys/longchat-13b-16k",
+        "lmsys/longchat-13b-16k",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+        rope_scaling=TEST_ROPE_SCALING,
+    )
+    assert getattr(longchat_model_config.hf_config, "rope_scaling",
+                   None) == TEST_ROPE_SCALING
+    assert longchat_model_config.max_model_len == 4096
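A note on the expected values in the new test: the asserts are consistent with vLLM deriving max_model_len by multiplying the model's base max_position_embeddings by the RoPE scaling factor. Meta-Llama-3-8B-Instruct ships with an 8192-token context and no rope_scaling, so a dynamic factor of 2.0 yields 16384; lmsys/longchat-13b-16k ships with linear scaling factor 8.0 over a 2048-token base (2048 * 8.0 = 16384), so overriding it with a dynamic factor of 2.0 shrinks the derived limit to 4096. The helper below is only a minimal sketch of that arithmetic, not the actual _get_and_verify_max_len implementation:

from typing import Optional

def derived_max_model_len(max_position_embeddings: int,
                          rope_scaling: Optional[dict]) -> int:
    # Hypothetical helper mirroring the arithmetic implied by the test;
    # the real logic lives in vllm/config.py (_get_and_verify_max_len).
    if rope_scaling is None:
        return max_position_embeddings
    return int(max_position_embeddings * rope_scaling["factor"])

assert derived_max_model_len(8192, None) == 8192                                 # Llama-3, unscaled
assert derived_max_model_len(8192, {"type": "dynamic", "factor": 2.0}) == 16384  # Llama-3 override
assert derived_max_model_len(2048, {"type": "linear", "factor": 8.0}) == 16384   # longchat default
assert derived_max_model_len(2048, {"type": "dynamic", "factor": 2.0}) == 4096   # longchat override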

vllm/config.py
+6 -1

@@ -45,6 +45,9 @@ class ModelConfig:
         code_revision: The specific revision to use for the model code on
             Hugging Face Hub. It can be a branch name, a tag name, or a
             commit id. If unspecified, will use the default version.
+        rope_scaling: Dictionary containing the scaling configuration for the
+            RoPE embeddings. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -84,6 +87,7 @@ def __init__(
         seed: int,
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
+        rope_scaling: Optional[dict] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
@@ -102,6 +106,7 @@ def __init__(
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
+        self.rope_scaling = rope_scaling
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.quantization_param_path = quantization_param_path
@@ -116,7 +121,7 @@ def __init__(
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision)
+                                    code_revision, rope_scaling)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
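Illustration of the new ModelConfig argument (a sketch mirroring the test above; it assumes access to the Hugging Face Hub): the scaling dict is pushed into the underlying HF config and max_model_len is re-derived automatically, which is why the docstring says not to bump max_position_embeddings yourself.

from vllm.config import ModelConfig

model_config = ModelConfig(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_mode="auto",
    trust_remote_code=False,
    dtype="float16",
    seed=0,
    rope_scaling={"type": "dynamic", "factor": 2.0},
)
print(model_config.hf_config.rope_scaling)  # {'type': 'dynamic', 'factor': 2.0}
print(model_config.max_model_len)           # 16384, i.e. 8192 * 2.0, derived for you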

vllm/engine/arg_utils.py
+13 -5

@@ -1,5 +1,6 @@
 import argparse
 import dataclasses
+import json
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
@@ -49,6 +50,7 @@ class EngineArgs:
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
+    rope_scaling: Optional[dict] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False
@@ -330,6 +332,11 @@ def add_cli_args(
                             'None, we assume the model weights are not '
                             'quantized and use `dtype` to determine the data '
                             'type of the weights.')
+        parser.add_argument('--rope-scaling',
+                            default=None,
+                            type=json.loads,
+                            help='RoPE scaling configuration in JSON format. '
+                            'For example, {"type":"dynamic","factor":2.0}')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
@@ -548,11 +555,12 @@ def create_engine_config(self, ) -> EngineConfig:
         model_config = ModelConfig(
             self.model, self.tokenizer, self.tokenizer_mode,
             self.trust_remote_code, self.dtype, self.seed, self.revision,
-            self.code_revision, self.tokenizer_revision, self.max_model_len,
-            self.quantization, self.quantization_param_path,
-            self.enforce_eager, self.max_context_len_to_capture,
-            self.max_seq_len_to_capture, self.max_logprobs,
-            self.skip_tokenizer_init, self.served_model_name)
+            self.code_revision, self.rope_scaling, self.tokenizer_revision,
+            self.max_model_len, self.quantization,
+            self.quantization_param_path, self.enforce_eager,
+            self.max_context_len_to_capture, self.max_seq_len_to_capture,
+            self.max_logprobs, self.skip_tokenizer_init,
+            self.served_model_name)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
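Usage sketch for the new flag: on the command line the value is parsed with json.loads, e.g. --rope-scaling '{"type":"dynamic","factor":2.0}' (single-quote the JSON so the shell leaves it intact). The same dict can be passed programmatically; the snippet below is a sketch that assumes the offline LLM entrypoint forwards extra keyword arguments to EngineArgs (as it does in this codebase) and a machine able to load the model.

from vllm import LLM, SamplingParams

# rope_scaling is forwarded via EngineArgs into ModelConfig.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
          rope_scaling={"type": "dynamic", "factor": 2.0})
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)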

vllm/engine/llm_engine.py
+6 -4

@@ -104,10 +104,11 @@ def __init__(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
-            "max_seq_len=%d, download_dir=%r, load_format=%s, "
-            "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
-            "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
+            "rope_scaling=%r, tokenizer_revision=%s, "
+            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
+            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
+            "disable_custom_all_reduce=%s, quantization=%s, "
+            "enforce_eager=%s, kv_cache_dtype=%s, "
             "quantization_param_path=%s, device_config=%s, "
             "decoding_config=%r, seed=%d, served_model_name=%s)",
             vllm.__version__,
@@ -117,6 +118,7 @@ def __init__(
             model_config.skip_tokenizer_init,
             model_config.tokenizer_mode,
             model_config.revision,
+            model_config.rope_scaling,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,

vllm/transformers_utils/config.py
+9 -1

@@ -2,9 +2,12 @@
 
 from transformers import AutoConfig, PretrainedConfig
 
+from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              JAISConfig, MPTConfig, RWConfig)
 
+logger = init_logger(__name__)
+
 _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
@@ -18,7 +21,8 @@
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None,
-               code_revision: Optional[str] = None) -> PretrainedConfig:
+               code_revision: Optional[str] = None,
+               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
             model,
@@ -41,6 +45,10 @@ def get_config(model: str,
         config = config_class.from_pretrained(model,
                                               revision=revision,
                                               code_revision=code_revision)
+    if rope_scaling is not None:
+        logger.info("Updating rope_scaling from %r to %r",
+                    getattr(config, "rope_scaling", None), rope_scaling)
+        config.update({"rope_scaling": rope_scaling})
     return config
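The override relies on transformers' PretrainedConfig.update(), which sets attributes from the given dict, so a user-supplied rope_scaling simply replaces whatever the checkpoint's config.json declared. A minimal sketch (assuming Hub access; the printed values follow the test above):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("lmsys/longchat-13b-16k")
print(getattr(config, "rope_scaling", None))   # expected: {'type': 'linear', 'factor': 8.0}

config.update({"rope_scaling": {"type": "dynamic", "factor": 2.0}})
print(config.rope_scaling)                     # {'type': 'dynamic', 'factor': 2.0}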
