File tree: examples/offline_inference
5 files changed, +35 −11 lines changed

@@ -22,7 +22,8 @@ def main():
     # In real workloads, `enforce_eager` should be `False`.
     llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
               max_num_batched_tokens=64,
-              max_num_seqs=4)
+              max_num_seqs=4,
+              max_model_len=128)
     outputs = llm.generate(prompts, sampling_params)
     print("-" * 50)
     for output, answer in zip(outputs, answers):
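For readers who want to run the example end to end, here is a minimal self-contained sketch of an offline-inference script along these lines; the prompts, answers, and SamplingParams values are illustrative assumptions, since only the LLM(...) call is visible in this hunk.

from vllm import LLM, SamplingParams


def main():
    # Illustrative inputs; the real example file defines its own prompts/answers.
    prompts = ["The capital of France is"]
    answers = [" Paris"]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

    # In real workloads, `enforce_eager` should be `False`.
    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
              max_num_batched_tokens=64,
              max_num_seqs=4,
              max_model_len=128)

    outputs = llm.generate(prompts, sampling_params)
    print("-" * 50)
    for output, answer in zip(outputs, answers):
        generated_text = output.outputs[0].text
        print(f"prompt:    {output.prompt!r}")
        print(f"generated: {generated_text!r}")
        print(f"expected:  {answer!r}")
        print("-" * 50)


if __name__ == "__main__":
    main()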
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250408
-torchvision==0.22.0.dev20250408
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250430
+torchvision==0.22.0.dev20250430
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
@@ -76,9 +76,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel

         cache_config = vllm_config.cache_config
+        # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
-
         compilation_config = vllm_config.compilation_config

         # TPU only supports DYNAMO_ONCE compilation level
@@ -101,16 +101,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if envs.VLLM_USE_V1:
             from vllm.v1.attention.backends.pallas import (
                 PallasAttentionBackend)
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)
             min_page_size = PallasAttentionBackend.get_min_page_size(
                 vllm_config)
-            if min_page_size > vllm_config.cache_config.block_size:
+            if min_page_size > cache_config.block_size:
                 logger.warning(
                     "Increase the page size from %s to %s to make sure there's"
                     " no SMEM OOM",
-                    vllm_config.cache_config.block_size,
+                    cache_config.block_size,
                     min_page_size,
                 )
-                vllm_config.cache_config.block_size = min_page_size
+                cache_config.block_size = min_page_size

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
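Net effect of the platform change: on the V1 path the block size is first derived from the model's maximum length via the new get_page_size() and is only raised afterwards if the SMEM-driven minimum requires it. Below is a standalone sketch of that resolution order, using plain integers instead of a VllmConfig, a hypothetical resolve_block_size helper, and an assumed min_page_size value (the real one comes from PallasAttentionBackend.get_min_page_size).

def next_power_of_2(n) -> int:
    # Same helper this PR adds to vllm.utils.
    return 1 if n < 1 else 1 << (n - 1).bit_length()


def get_page_size(max_model_len: int) -> int:
    # Target roughly 16 pages per max-model-len, clamped to [16, 256].
    return min(max(next_power_of_2(max_model_len) // 16, 16), 256)


def resolve_block_size(max_model_len: int, min_page_size: int) -> int:
    block_size = get_page_size(max_model_len)
    if min_page_size > block_size:
        # Mirrors the warning branch: bump the page size to avoid SMEM OOM.
        block_size = min_page_size
    return block_size


# A model with max_model_len=1024 and an assumed min_page_size of 128:
# next_power_of_2(1024) // 16 = 64, which is below the minimum, so 128 wins.
print(resolve_block_size(1024, 128))  # -> 128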
@@ -707,6 +707,13 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)


+def next_power_of_2(n) -> int:
+    """The next power of 2 (inclusive)"""
+    if n < 1:
+        return 1
+    return 1 << (n - 1).bit_length()
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y
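The new helper treats powers of two as already rounded ("inclusive") and clamps anything below 1 up to 1. A few illustrative values, importing it from vllm.utils as the Pallas backend hunk below does:

from vllm.utils import next_power_of_2

assert next_power_of_2(0) == 1     # anything below 1 clamps to 1
assert next_power_of_2(1) == 1
assert next_power_of_2(5) == 8
assert next_power_of_2(16) == 16   # powers of two map to themselves
assert next_power_of_2(1000) == 1024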
@@ -12,7 +12,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils import cdiv, next_power_of_2

 logger = init_logger(__name__)

@@ -65,6 +65,20 @@ def get_min_page_size(vllm_config: VllmConfig) -> int:
         min_page_size = 1 << (min_page_size - 1).bit_length()
         return min_page_size

+    # TPU has limited SREGs (scalar registers); if page_size is too small, we
+    # can spill SREGs easily, which leads to bad performance. The strategy here
+    # is to split max-model-len into 16 pages, which makes spills less likely.
+    # Meanwhile, we make sure the page size stays in [16, 256].
+    @staticmethod
+    def get_page_size(vllm_config: VllmConfig) -> int:
+        page_size = next_power_of_2(
+            vllm_config.model_config.max_model_len) // 16
+        if page_size <= 16:
+            return 16
+        if page_size >= 256:
+            return 256
+        return page_size
+

 @dataclass
 class PallasMetadata:
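To see the heuristic in action, the new static method can be exercised directly. Since get_page_size only reads model_config.max_model_len, a duck-typed stand-in for VllmConfig is enough for a quick check; this assumes an environment where the Pallas backend imports cleanly, i.e. the TPU dependencies pinned above are installed.

from types import SimpleNamespace

from vllm.v1.attention.backends.pallas import PallasAttentionBackend


def fake_config(max_model_len: int) -> SimpleNamespace:
    # Stand-in for VllmConfig carrying only the field get_page_size reads.
    return SimpleNamespace(
        model_config=SimpleNamespace(max_model_len=max_model_len))


print(PallasAttentionBackend.get_page_size(fake_config(128)))    # 8 -> clamped up to 16
print(PallasAttentionBackend.get_page_size(fake_config(2048)))   # 128
print(PallasAttentionBackend.get_page_size(fake_config(32768)))  # 2048 -> capped at 256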