
Commit d54e516

add cpu core pinning to vllm-server on Gaudi3 + GNR for Llama405B and 70B
1 parent 0924722 commit d54e516

File tree

8 files changed: +414 additions, −2 deletions


.cd/README.md

Lines changed: 32 additions & 1 deletion
@@ -64,6 +64,37 @@ cd vllm-gaudi/.cd/

This launches the vLLM server and runs the benchmark suite automatically.

#### 2.1 (Optional) Running the server with a benchmark and pinning CPU cores for memory access coherence

To improve memory access coherence and free the remaining CPUs for other CPU-only workloads (for example, a CPU-based vLLM server running Llama3 8B), pin the CPU cores per NUMA node using an auto-generated docker-compose.override.yml file.
The Python scripts need a couple of libraries, so install the required packages with the following command.

```bash
pip install -r vllm-fork/.cd/server/requirements_cpu_binding.txt
```

Run the commands below to generate docker-compose.override.yml and launch the server with its CPU cores pinned.

```bash
cd vllm-fork/.cd/
export MODEL="Qwen/Qwen2.5-14B-Instruct"
export HF_TOKEN="<your huggingface token>"
export DOCKER_IMAGE="vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/vllm-installer-2.7.1:latest"
python3 server/generate_cpu_binding_from_csv.py --settings server/cpu_binding.csv --output ./docker-compose.override.yml
docker compose --profile benchmark -f docker-compose.yml -f docker-compose.override.yml up
```

To also bind the remaining idle CPUs to another service such as vllm-cpu-service, pass the service name via --cpuservice so that docker-compose.override.yml pins that service to the idle CPUs as well.
The example below binds the idle CPUs to the vllm-cpu-service service, which is defined in docker-compose.vllm-cpu-service.yml. A sketch of what the generated override file might look like follows these examples.

```bash
cd vllm-fork/.cd/
export MODEL="Qwen/Qwen2.5-14B-Instruct"
export HF_TOKEN="<your huggingface token>"
export DOCKER_IMAGE="vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/vllm-installer-2.7.1:latest"
python3 server/generate_cpu_binding_from_csv.py --settings server/cpu_binding.csv --output ./docker-compose.override.yml --cpuservice vllm-cpu-service
docker compose --profile benchmark -f docker-compose.yml -f docker-compose.vllm-cpu-service.yml -f docker-compose.override.yml up
```
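
For orientation, a generated docker-compose.override.yml could look roughly like the sketch below. This is an illustration, not actual tool output: the service names, core ranges, and the use of the Compose cpuset field are assumptions based on the description above; the real contents depend on your host's NUMA topology and the settings CSV.

```yaml
# Hypothetical example of a generated docker-compose.override.yml (illustrative only)
services:
  vllm-server:                 # assumed name of the Gaudi serving service
    cpuset: "110-127,238-255"  # cores local to the NUMA nodes of the Gaudi cards
  vllm-cpu-service:            # only added when --cpuservice is passed
    cpuset: "0-85"             # remaining idle cores
```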

### 3. Run the server using Docker Compose with custom parameters

To override default settings, you can provide additional parameters when starting the server. This is a more advanced approach:
@@ -129,7 +160,7 @@ cd vllm-gaudi/.cd/
MAX_MODEL_LEN=2048 \
INPUT_TOK=128 \
OUTPUT_TOK=128 \
- CON_REQ=16 \
+ CONCURRENT_REQ=16 \
NUM_PROMPTS=64 \
docker compose --profile benchmark up
```

.cd/benchmark/benchmark_user.env

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
MODEL
INPUT_TOK
OUTPUT_TOK
- CON_REQ
+ CONCURRENT_REQ
NUM_PROMPTS

.cd/docker-compose.yml

Lines changed: 2 additions & 0 deletions
@@ -42,4 +42,6 @@ services:
      - PYTHONUNBUFFERED=1
    env_file:
      - ./benchmark/benchmark_user.env
+   volumes:
+     - ./logs:/root/scripts/logs
    command: ["benchmark", "--config-file", "${VLLM_BENCHMARK_CONFIG_FILE}", "--config-name", "${VLLM_BENCHMARK_CONFIG_NAME}"]

.cd/server/cpu_binding.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
import os
import csv
from importlib import util
from enum import Enum
from typing import List

from gaudi_topology import GaudiTopology

REQUIRED_COLUMNS = ["model_id", "input_length", "output_length", "world_size", "data_type", "num_allocated_cpu"]


class BindingPolicy(Enum):
    Evenly_on_NUMAs = "evenly"
    NUMAs_with_cards = "close2cards"


class CPU_Binding():

    def __init__(self,
                 csv_path: str = "cpu_binding_gnr.csv",
                 use_hyperthread: bool = False):
        self.libnuma_found = util.find_spec("numa") is not None
        self.psutil_found = util.find_spec("psutil") is not None
        if self.libnuma_found and self.psutil_found:
            import psutil
            from numa import info
            # Get system info
            self.cpu_count = psutil.cpu_count(logical=False)
            self.cpus_allow_list = psutil.Process().cpu_affinity()
            self.numa_size = info.get_num_configured_nodes()
            self.cpu_count_per_numa = self.cpu_count // self.numa_size

        # Get CSV info
        with open(csv_path, newline="") as f:
            rows = list(csv.DictReader(f))
        if not rows or any(col not in rows[0] for col in REQUIRED_COLUMNS):
            found = list(rows[0].keys()) if rows else "EMPTY CSV"
            raise ValueError(f"CSV missing required headers {REQUIRED_COLUMNS}. Found: {found}")
        model = os.environ.get("MODEL")
        if not model:
            raise RuntimeError("Set environment variable MODEL to a model_id in the CSV "
                               "(e.g., export MODEL='meta-llama/Llama-3.1-8B-Instruct').")
        input_tok = os.environ.get("INPUT_TOK")
        output_tok = os.environ.get("OUTPUT_TOK")
        con_req = os.environ.get("CONCURRENT_REQ")
        num_allocated_cpu = os.environ.get("NUM_CPUS")
        print(num_allocated_cpu)

        row = self.pick_row_by_parameters(rows, model, input_tok, output_tok, con_req)
        print(row["num_allocated_cpu"])

        self.world_size = self.parse_int(row["world_size"], "world_size")
        binding_policy_index = self.parse_int(row["binding_policy"], "binding_policy")
        self.binding_policy = list(BindingPolicy)[binding_policy_index]

        if num_allocated_cpu:
            self.num_allocated_cpu = int(num_allocated_cpu)
        elif row["num_allocated_cpu"] == 'NA':
            raise RuntimeError("Invalid num_allocated_cpu value in the CSV. Set environment variable NUM_CPUS instead.")
        else:
            self.num_allocated_cpu = self.parse_int(row["num_allocated_cpu"], "num_allocated_cpu")

        # CPU: build per-NUMA-node CPU lists, restricted to the allowed affinity mask
        self.node_to_cpus = []
        for i in range(self.numa_size):
            from numa import info
            node_intersect = [cpu for cpu in info.node_to_cpus(i) if cpu in self.cpus_allow_list]
            if node_intersect:
                self.node_to_cpus.append(list(node_intersect))
        self.node_to_idle_cpus = self.node_to_cpus.copy()
        for i in range(self.numa_size):
            if use_hyperthread is False:
                self.node_to_idle_cpus[i] = self.node_to_cpus[i][:self.cpu_count_per_numa]
            else:
                self.node_to_idle_cpus[i] = self.node_to_cpus[i][self.cpu_count_per_numa:]

        # Gaudi: map cards to their NUMA nodes
        topo = GaudiTopology()
        self.cards = topo.get_cards()
        if self.cards is not None:
            self.gaudi_numa_list = []
            # Assume cards 0 to world_size-1 are used
            for card in self.cards[:self.world_size]:
                if card['numa_node'] not in self.gaudi_numa_list:
                    self.gaudi_numa_list.append(card['numa_node'])
                print(f"Card {card['card_id']} ({card['model']}):")
                print(f"  Bus ID      : {card['bus_id']}")
                print(f"  NUMA Node   : {card['numa_node']}")
                print(f"  Local CPUs  : {card['local_cpulist']}")

    def parse_int(self, v: str, name: str) -> int:
        try:
            return int(v)
        except Exception:
            raise ValueError(f"Invalid integer for {name!r}: {v!r}")

    def pick_row_by_parameters(self, rows: List[dict], model: str, input_tok: str, output_tok: str, con_req: str) -> dict:
        matches = [
            r for r in rows
            if r.get("model_id", "").strip() == model
            if r.get("input_length", "").strip() == input_tok
            if r.get("output_length", "").strip() == output_tok
        ]
        if not matches:
            available = ", ".join(sorted({r.get('model_id', '') for r in rows}))
            raise ValueError(f"MODEL '{model}', input_length '{input_tok}', output_length '{output_tok}' "
                             f"not found in CSV. Available: {available}")
        return matches[0]

    def get_cpus_id_binding_based_on_numa_nodes(self, rank: int) -> str:
        """Return the CPU id binding for a rank based on NUMA nodes."""
        rank_to_cpus = ''
        if not self.libnuma_found or not self.psutil_found:
            print("Auto thread-binding is not supported due to "
                  "the lack of the numa and psutil packages; "
                  "falling back to no thread-binding. To get better performance, "
                  "please try to bind threads manually.")
            return rank_to_cpus

        if self.binding_policy is BindingPolicy.Evenly_on_NUMAs or self.cards is None:
            divider = min(self.world_size, len(self.node_to_cpus))
            self.allocated_cpu_per_numa = self.num_allocated_cpu // divider
            node_id = rank
        elif self.binding_policy is BindingPolicy.NUMAs_with_cards:
            self.allocated_cpu_per_numa = self.num_allocated_cpu // len(self.gaudi_numa_list)
            node_id = int(self.cards[rank]['numa_node'])

        print("binding numa node_id %d allocated_cpu_per_numa %d" % (node_id, self.allocated_cpu_per_numa))
        # Option 1. Bind to the last N cpu cores of the node
        start = self.cpu_count_per_numa - self.allocated_cpu_per_numa
        rank_to_cpus_list = self.node_to_cpus[node_id][start:self.cpu_count_per_numa]
        # Option 2. Bind to the first N cpu cores of the node
        # rank_to_cpus_list = self.node_to_cpus[node_id][:self.allocated_cpu_per_numa]

        rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
        print("rank %d auto thread-binding list: %s" % (rank, rank_to_cpus))
        self.node_to_idle_cpus[node_id] = [cpu for cpu in self.node_to_idle_cpus[node_id]
                                           if cpu not in rank_to_cpus_list]
        return rank_to_cpus


if __name__ == "__main__":
    libnuma_found = util.find_spec("numa") is not None
    if libnuma_found:
        from numa import info
        numa_size = info.get_num_configured_nodes()
    else:
        numa_size = 1
    world_size = numa_size
    cpu_binder = CPU_Binding(use_hyperthread=False)
    max_needed_numa_size = min(cpu_binder.world_size, cpu_binder.numa_size)
    for i in range(max_needed_numa_size):
        rank_to_cpus = cpu_binder.get_cpus_id_binding_based_on_numa_nodes(i)
        print(rank_to_cpus)

    rank_to_idle_cpus = ','.join(str(x) for row in cpu_binder.node_to_idle_cpus for x in row)
    print(rank_to_idle_cpus)
    for r in cpu_binder.node_to_idle_cpus:
        print(len(r))
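
To make the binding arithmetic above concrete, here is a minimal standalone sketch of the "Option 1" core selection for the Llama-3.1-70B rows (num_allocated_cpu=12, binding_policy 0, i.e. evenly across NUMA nodes). The host layout is invented for illustration (2 NUMA nodes with 43 physical cores each); on a real system these values come from psutil and libnuma.

```python
# Illustrative sketch of CPU_Binding's per-rank core selection (assumed host values).
num_allocated_cpu = 12                                     # from the 70B rows in cpu_binding_gnr.csv
world_size = 4                                             # tensor-parallel ranks
node_to_cpus = [list(range(0, 43)), list(range(43, 86))]   # assumption: 2 NUMA nodes, 43 cores each
cpu_count_per_numa = 43

divider = min(world_size, len(node_to_cpus))               # Evenly_on_NUMAs policy
allocated_cpu_per_numa = num_allocated_cpu // divider      # 12 // 2 = 6 cores per NUMA node

for rank in range(len(node_to_cpus)):
    start = cpu_count_per_numa - allocated_cpu_per_numa    # Option 1: take the last N cores of the node
    cores = node_to_cpus[rank][start:cpu_count_per_numa]
    print(f"rank {rank}: {','.join(map(str, cores))}")
# rank 0: 37,38,39,40,41,42
# rank 1: 80,81,82,83,84,85
```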

.cd/server/cpu_binding_gnr.csv

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
model_id,input_length,output_length,world_size,data_type,num_allocated_cpu,binding_policy
meta-llama/Llama-3.1-405B-Instruct,128,4096,8,bf16,18,0
meta-llama/Llama-3.1-405B-Instruct,2048,2048,8,bf16,18,0
meta-llama/Llama-3.1-405B-Instruct,4096,128,8,bf16,18,0
meta-llama/Llama-3.1-70B-Instruct,128,4096,4,bf16,12,0
meta-llama/Llama-3.1-70B-Instruct,2048,2048,4,bf16,12,0
meta-llama/Llama-3.1-70B-Instruct,4096,128,4,bf16,12,0
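
A note on the last column: cpu_binding.py treats binding_policy as an index into its BindingPolicy enum, so the 0 in every row above selects the "evenly" policy. A minimal sketch of that mapping (the enum mirrors the one defined in cpu_binding.py):

```python
from enum import Enum

class BindingPolicy(Enum):            # mirrors cpu_binding.py
    Evenly_on_NUMAs = "evenly"        # binding_policy = 0 in the CSV
    NUMAs_with_cards = "close2cards"  # binding_policy = 1

print(list(BindingPolicy)[0])         # BindingPolicy.Evenly_on_NUMAs
```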

.cd/server/gaudi_topology.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# ==============================================================================
# gaudi_topology.py
# Provides GaudiTopology class:
#   - discover all Gaudi cards via hl-smi
#   - return NUMA node and CPU IDs per card
# Works with hl-smi v1.22.0+ (HL-325L / Gaudi3) table format.
# ==============================================================================

import subprocess
import re
import os
import shutil
from typing import List, Dict, Optional


class GaudiTopology:
    """Utility class to discover Gaudi cards and their NUMA / CPU locality."""

    def __init__(self):
        self.cards = self._discover_cards()

    # ------------------------------------------------------------------
    def _run_cmd(self, cmd: str) -> str:
        """Run a shell command and return stdout."""
        try:
            result = subprocess.run(cmd, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                    text=True)
            return result.stdout
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Command failed: {cmd}\n{e.stderr}")

    # ------------------------------------------------------------------
    def _parse_hl_smi_table(self, text: str) -> List[Dict]:
        """
        Parse hl-smi v1.22+ table format.
        Example line:
        |   0  HL-325L             N/A  | 0000:97:00.0     N/A |
        """
        cards = []
        pattern = re.compile(
            r'^\|\s*(\d+)\s+([A-Z0-9-]+)\s+N/A\s+\|\s*([0-9a-fA-F:.]+)\s+N/A\s*\|'
        )
        for line in text.splitlines():
            match = pattern.match(line)
            if not match:
                continue
            card_id, model, bus_id = match.groups()
            if not bus_id.startswith("0000:"):
                bus_id = "0000:" + bus_id
            cards.append({
                "card_id": int(card_id),
                "model": model,
                "bus_id": bus_id
            })
        return cards

    # ------------------------------------------------------------------
    def _get_sysfs_info(self, bus_id: str) -> Dict[str, Optional[str]]:
        """Fetch NUMA node and local CPU list from sysfs."""
        sys_path = f"/sys/bus/pci/devices/{bus_id}"
        info = {"numa_node": None, "local_cpulist": None}
        try:
            with open(os.path.join(sys_path, "numa_node")) as f:
                info["numa_node"] = f.read().strip()
        except FileNotFoundError:
            pass
        try:
            with open(os.path.join(sys_path, "local_cpulist")) as f:
                info["local_cpulist"] = f.read().strip()
        except FileNotFoundError:
            pass
        return info

    # ------------------------------------------------------------------
    def _discover_cards(self) -> Optional[List[Dict]]:
        """Run hl-smi and discover Gaudi cards."""
        if shutil.which("hl-smi") is None:
            print("No hl-smi found")
            return None

        hl_smi_output = self._run_cmd("hl-smi")
        cards = self._parse_hl_smi_table(hl_smi_output)
        for c in cards:
            sysfs_info = self._get_sysfs_info(c["bus_id"])
            c.update(sysfs_info)
        return cards

    # ------------------------------------------------------------------
    def get_cards(self) -> Optional[List[Dict]]:
        """Return list of all discovered cards sorted by NUMA node (then card_id)."""
        if self.cards is None:
            return None

        def sort_key(c):
            # Convert numa_node to int when possible, else put N/A at the end
            try:
                return (int(c["numa_node"]), c["card_id"])
            except (TypeError, ValueError):
                return (999, c["card_id"])
        return sorted(self.cards, key=sort_key)

    # ------------------------------------------------------------------
    def get_numa_for_card(self, card_id: int) -> Optional[str]:
        """Return NUMA node for a given card ID."""
        for c in self.cards:
            if c["card_id"] == card_id:
                return c["numa_node"]
        return None

    # ------------------------------------------------------------------
    def get_cpus_for_card(self, card_id: int) -> Optional[str]:
        """Return local CPU list for a given card ID."""
        for c in self.cards:
            if c["card_id"] == card_id:
                return c["local_cpulist"]
        return None

# ------------------------------------------------------------------------------

if __name__ == "__main__":
    topo = GaudiTopology()
    for card in topo.get_cards() or []:
        print(f"Card {card['card_id']} ({card['model']}):")
        print(f"  Bus ID      : {card['bus_id']}")
        print(f"  NUMA Node   : {card['numa_node']}")
        print(f"  Local CPUs  : {card['local_cpulist']}")
        print()
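
As a quick sanity check of the table parser above, the regex can be exercised against the example line from the _parse_hl_smi_table docstring (a minimal sketch; column spacing in real hl-smi output may differ):

```python
import re

# Same pattern as GaudiTopology._parse_hl_smi_table
pattern = re.compile(
    r'^\|\s*(\d+)\s+([A-Z0-9-]+)\s+N/A\s+\|\s*([0-9a-fA-F:.]+)\s+N/A\s*\|'
)
sample = "|   0  HL-325L             N/A  | 0000:97:00.0     N/A |"
print(pattern.match(sample).groups())   # ('0', 'HL-325L', '0000:97:00.0')
```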
