
[bugfix] Fix auto thread-binding when world_size > 1 in CPU backend and refactor code #21032

Merged (18 commits) on Jul 19, 2025

4 changes: 2 additions & 2 deletions .buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -24,8 +24,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .

# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2

function cpu_tests() {
set -e
10 changes: 7 additions & 3 deletions docs/getting_started/installation/cpu.md
@@ -94,8 +94,8 @@ Currently, there are no pre-built CPU wheels.
## Related runtime environment variables

- `VLLM_CPU_KVCACHE_SPACE`: specify the KV cache size (e.g., `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node. By setting to `all`, the OpenMP threads of each rank uses all CPU cores available on the system. Default value is `auto`.
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `0`.
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads; can be set to CPU ID lists or `auto` (the default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound to CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, with the 32 OpenMP threads of rank0 bound to CPU cores 0-31 and the OpenMP threads of rank1 bound to CPU cores 32-63. When set to `auto`, the OpenMP threads of each rank are bound to the CPU cores of a separate NUMA node.
Contributor:

Why do we remove `VLLM_CPU_NUM_OF_RESERVED_CPU` since we still have it as an optional var?

Collaborator (Author):

You are right. I'll revert it.

I wanted to set this value in the worker based on some rules and not expose it to users. However, CPUWorker doesn't have enough usage context, so users should set it manually in some cases.

- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`. Default value is `None`. If the value is not set and `auto` thread binding is used, no CPU is reserved when `world_size == 1`, and 1 CPU per rank is reserved when `world_size > 1`.
- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
- `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False).

@@ -123,9 +123,13 @@ export VLLM_CPU_NUM_OF_RESERVED_CPU=1
vllm serve facebook/opt-125m --dtype=bfloat16
```

Note: it is recommended to manually reserve 1 CPU for the vLLM front-end process when `world_size == 1`.

### How to decide `VLLM_CPU_OMP_THREADS_BIND`?

- Bind each OpenMP thread to a dedicated physical CPU core respectively, or use auto thread binding feature by default. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
- Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread is bound to a dedicated physical core, the threads of each rank are bound to the same NUMA node, and 1 CPU per rank is reserved for other vLLM components when `world_size > 1`. If you hit any performance problems or unexpected binding behaviour, please try binding threads manually as follows.

- On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:

??? console "Commands"

2 changes: 0 additions & 2 deletions requirements/cpu.txt
@@ -24,6 +24,4 @@ datasets # for benchmark scripts
# Intel Extension for PyTorch, only for x86_64 CPUs
intel-openmp==2024.2.1; platform_machine == "x86_64"
intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
py-libnuma; platform_system != "Darwin"
psutil; platform_system != "Darwin"
triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.
5 changes: 3 additions & 2 deletions vllm/envs.py
@@ -44,7 +44,7 @@
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0
VLLM_CPU_NUM_OF_RESERVED_CPU: Optional[int] = None
VLLM_CPU_MOE_PREPACK: bool = True
VLLM_CPU_SGL_KERNEL: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
@@ -441,7 +441,8 @@ def get_vllm_port() -> Optional[int]:
# (CPU backend only) CPU cores not used by OMP threads.
# Those CPU cores will not be used by OMP threads of a rank.
"VLLM_CPU_NUM_OF_RESERVED_CPU":
lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")),
lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0"))
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None,
Comment on lines +444 to +445
Contributor:

critical

The current implementation for parsing `VLLM_CPU_NUM_OF_RESERVED_CPU` is vulnerable to a crash. If the environment variable is set to an empty string (e.g., `export VLLM_CPU_NUM_OF_RESERVED_CPU=""`), `os.getenv` will return `""`, and `int("")` will raise a `ValueError`, causing the application to terminate. To make it more robust, you should handle the empty string case, for example by treating it as 0.

Suggested change
lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0"))
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None,
lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU") or "0")
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None,


# (CPU backend only) whether to use prepack for MoE layer. This will be
# passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might
64 changes: 64 additions & 0 deletions vllm/platforms/cpu.py
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
import platform
import subprocess
import sys
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING, Optional

@@ -31,6 +34,35 @@ def get_max_threads(pid=0):
raise NotImplementedError("Unsupported OS")


@dataclass
class LogicalCPUInfo:
id: int = -1
physical_core: int = -1
numa_node: int = -1

@classmethod
def _int(cls, value: str) -> int:
try:
int_value = int(value)
except Exception:
int_value = -1
return int_value

@staticmethod
def json_decoder(obj_dict: dict):
id = obj_dict.get("cpu")
physical_core = obj_dict.get("core")
numa_node = obj_dict.get("node")

if not (id is None or physical_core is None or numa_node is None):
return LogicalCPUInfo(
id=LogicalCPUInfo._int(id),
physical_core=LogicalCPUInfo._int(physical_core),
numa_node=LogicalCPUInfo._int(numa_node))
else:
return obj_dict
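As a rough usage sketch of the decoder above: the JSON payload is a hand-written stand-in for `lscpu -J -e=CPU,CORE,NODE` output (not captured from a real machine), and the import path assumes the class lands in `vllm/platforms/cpu.py` as in this diff.

```python
import json

from vllm.platforms.cpu import LogicalCPUInfo

# Hand-written stand-in for `lscpu -J -e=CPU,CORE,NODE` on a machine with
# 2 logical CPUs on 2 NUMA nodes; keys match what json_decoder looks for.
sample_lscpu_json = """
{
  "cpus": [
    {"cpu": 0, "core": 0, "node": 0},
    {"cpu": 1, "core": 1, "node": 1}
  ]
}
"""

# object_hook runs on every JSON object: entries carrying "cpu"/"core"/"node"
# become LogicalCPUInfo; anything else (e.g. the top-level dict) is returned
# unchanged, so the "cpus" key is still addressable.
cpus = json.loads(sample_lscpu_json,
                  object_hook=LogicalCPUInfo.json_decoder)["cpus"]
assert all(isinstance(c, LogicalCPUInfo) for c in cpus)
assert cpus[1].numa_node == 1
```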


class CpuPlatform(Platform):
_enum = PlatformEnum.CPU
device_name: str = "cpu"
@@ -240,6 +272,38 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)

@classmethod
def get_allowed_cpu_memory_node_list(
cls) -> tuple[list[int], list[LogicalCPUInfo]]:
assert platform.system() == "Linux"

# Init LogicalCPUInfo from lscpu
lscpu_output = subprocess.check_output("lscpu -J -e=CPU,CORE,NODE",
shell=True,
text=True)
logical_cpu_list: list[LogicalCPUInfo] = json.loads(
lscpu_output, object_hook=LogicalCPUInfo.json_decoder)['cpus']

# Filter CPUs with invalid attributes
logical_cpu_list = [
x for x in logical_cpu_list
if -1 not in (x.id, x.physical_core, x.numa_node)
]

# Filter allowed CPUs
allowed_cpu_id_list = os.sched_getaffinity(0)
logical_cpu_list = [
x for x in logical_cpu_list if x.id in allowed_cpu_id_list
]

# Get allowed NUMA nodes
allowed_numa_nodes = set()
for x in logical_cpu_list:
allowed_numa_nodes.add(x.numa_node) # type: ignore
allowed_numa_nodes_list = sorted(allowed_numa_nodes)

return allowed_numa_nodes_list, logical_cpu_list
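A hedged sketch of how a per-rank caller might consume this result for auto binding (requires Linux with `lscpu` available, as the method asserts; the rank-to-node assignment and reservation below are assumptions for illustration, not the worker code from this PR):

```python
from vllm.platforms.cpu import CpuPlatform

rank, reserved_cpus = 0, 1  # illustrative values, not PR defaults

numa_nodes, logical_cpus = CpuPlatform.get_allowed_cpu_memory_node_list()

# Pick one NUMA node per rank, then keep one logical CPU per physical core on
# that node (i.e. avoid binding two OpenMP threads to hyper-thread siblings).
node = numa_nodes[rank % len(numa_nodes)]
seen_cores: set[int] = set()
bind_cpu_ids: list[int] = []
for cpu in logical_cpus:
    if cpu.numa_node == node and cpu.physical_core not in seen_cores:
        seen_cores.add(cpu.physical_core)
        bind_cpu_ids.append(cpu.id)

# Leave `reserved_cpus` cores free for other vLLM components (front-end, etc.).
if reserved_cpus:
    bind_cpu_ids = bind_cpu_ids[:-reserved_cpus]

print(f"rank {rank}: NUMA node {node}, OMP threads on CPUs {bind_cpu_ids}")
```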

@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on CPU.")
7 changes: 4 additions & 3 deletions vllm/v1/worker/cpu_model_runner.py
@@ -45,9 +45,10 @@ def replace_tensor(obj: Any, cpu_attr_name: str,
if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
replace_tensor(self.input_batch, k, k[:-11])

for k, v in vars(self.input_batch.block_table).items():
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
replace_tensor(self.input_batch.block_table, k, k[:-4])
for block_table in self.input_batch.block_table.block_tables:
for k, v in vars(block_table).items():
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
replace_tensor(block_table, k, k[:-4])
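A minimal, self-contained sketch of the pattern above, assuming `replace_tensor` simply points the device-side attribute at its `*_cpu` counterpart (the toy `_BlockTable` class below is hypothetical; the real objects are the per-group block tables inside `MultiGroupBlockTable`):

```python
from typing import Any

import torch


def replace_tensor(obj: Any, cpu_attr_name: str, attr_name: str) -> None:
    # Assumed behaviour: alias the device attribute to the CPU tensor so the
    # CPU backend works on one shared buffer instead of a stale copy.
    cpu_tensor = getattr(obj, cpu_attr_name, None)
    device_tensor = getattr(obj, attr_name, None)
    if isinstance(cpu_tensor, torch.Tensor) and isinstance(
            device_tensor, torch.Tensor):
        setattr(obj, attr_name, cpu_tensor)


class _BlockTable:
    def __init__(self) -> None:
        self.block_table_cpu = torch.zeros(4, 8, dtype=torch.int32)
        self.block_table = self.block_table_cpu.clone()  # pretend device copy


# Stand-in for self.input_batch.block_table.block_tables in the loop above.
tables = [_BlockTable(), _BlockTable()]
for block_table in tables:
    for k, v in vars(block_table).items():
        if k.endswith("_cpu") and isinstance(v, torch.Tensor):
            replace_tensor(block_table, k, k[:-4])

assert tables[0].block_table is tables[0].block_table_cpu
```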

def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Starting to load model %s...", self.model_config.model)