[bugfix] Fix auto thread-binding when world_size > 1 in CPU backend and refactor code #21032

Merged · 18 commits · Jul 19, 2025

Conversation

bigPYJ1151
Contributor

@bigPYJ1151 commented on Jul 16, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

  • The current CPU backend auto thread-binding behaves unexpectedly when world_size > 1. This PR fixes it.
  • There is a lot of duplicated code between the x86 and POWERPC auto thread-binding paths. This PR merges them.
  • Enable auto thread-binding in the x86 CI tests.

Test Plan

Test Result

(Optional) Documentation Update

Contributor

@gemini-code-assist (bot) left a comment


Code Review

The pull request fixes an auto thread-binding bug and refactors code related to CPU affinity in vLLM. The changes include modifications to shell scripts, documentation, environment variable handling, and CPU worker logic to improve robustness and performance.

Comment on lines +444 to +445
lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0"))
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None,
Contributor


critical

The current implementation for parsing VLLM_CPU_NUM_OF_RESERVED_CPU is vulnerable to a crash. If the environment variable is set to an empty string (e.g., export VLLM_CPU_NUM_OF_RESERVED_CPU=""), os.getenv will return "", and int("") will raise a ValueError, causing the application to terminate. To make it more robust, you should handle the empty string case, for example by treating it as 0.

Suggested change
lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0"))
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None,
lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU") or "0")
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None,

Comment on lines 166 to 140
assert len(logical_cpu_list) != 0, (
f"No allowed CPU on NUMA node {self.local_rank}. "
f"Allowed CPU ids are {allowed_cpu_id_list}. "
"Their NUMA nodes can be got via `lscpu`. "
"Please try to bind threads manually.")
Contributor


high

The assertion message is missing a %s placeholder for allowed_cpu_id_list. This could make debugging difficult if the assertion fails, as the user will not see the value of the allowed CPU IDs.

            f"No allowed CPU on NUMA node {self.local_rank}. "
            f"Allowed CPU ids are {allowed_cpu_id_list}. "
            "Their NUMA nodes can be got via `lscpu`. "
            "Please try to bind threads manually.")

@bigPYJ1151
Contributor Author

cc @Akashcodes732 for the validation on POWERPC, thanks.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@Akashcodes732
Contributor

Thanks @bigPYJ1151, I will validate this for POWERPC.

Contributor

@louie-tsai left a comment


Could you remind me what the original issue is? Was the issue introduced by the ppc64le support?

@@ -94,8 +94,7 @@ Currently, there are no pre-built CPU wheels.
## Related runtime environment variables

- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node. By setting to `all`, the OpenMP threads of each rank uses all CPU cores available on the system. Default value is `auto`.
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `0`.
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively.
Contributor


Why do we remove VLLM_CPU_NUM_OF_RESERVED_CPU since we still have it as an optional variable?

Contributor Author


You are right. Reverted it.

I wanted to set this value in the worker based on some rules and not expose it to users. However, CPUWorker doesn't have enough usage context, so users should set it manually in some cases.

else:
self.local_omp_cpuid = (
self.get_cpus_id_binding_based_on_numa_nodes())
self.local_omp_cpuid = "all"
Contributor


In this case, users couldn't set omp_cpuid=all.
Why don't we let users set "all" themselves if they don't want any binding?

Contributor Author


Yes, I think auto binding is good enough as a default option.

"please try to manually bind threads.")
return rank_to_cpus
@dataclass
class LogicalCPUInfo:
Contributor


This code looks much more complicated than the previous one. Why do we want to change the original simple code into this complicated code? Why don't we just use the well-maintained psutil and libnuma to find out this information?

Contributor Author


Yes, it looks a bit complicated. In fact, it is just some operations (filter, slice, ...) on a list.

I feel the APIs of psutil and libnuma are not flexible enough to provide a data structure that expresses CPU attributes and layout; the custom data structure can support SMT-2 (x86), SMT-8 (ppc64le) and further features more easily.
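For illustration, a minimal sketch of the kind of structure I mean (field names are assumptions for this example, not the exact definition in the PR):

    from collections import defaultdict
    from dataclasses import dataclass

    @dataclass
    class LogicalCPUInfo:
        # Assumed fields for this sketch: logical CPU id, physical core id, NUMA node id.
        id: int
        physical_core: int
        numa_node: int

    def group_by_core(cpus: list[LogicalCPUInfo]) -> dict[int, list[LogicalCPUInfo]]:
        """Group logical CPUs by physical core, e.g. to later pick N threads per core."""
        core_to_cpus: dict[int, list[LogicalCPUInfo]] = defaultdict(list)
        for cpu in cpus:
            core_to_cpus[cpu.physical_core].append(cpu)
        return core_to_cpus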

@bigPYJ1151
Contributor Author

> Could you remind me what the original issue is? Was the issue introduced by the ppc64le support?

@louie-tsai I think it is probably a pre-existing issue, as ppc64le implemented auto-binding with a different method.

When we add -tp=4 and launch multiple workers, the current thread-binding lists are:

(VllmWorker rank=0 pid=1701267) INFO 07-17 05:05:46 [cpu_worker.py:152] auto thread-binding list: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31
(VllmWorker rank=1 pid=1701268) INFO 07-17 05:05:46 [cpu_worker.py:152] auto thread-binding list: 178,179,163,180,181,182,183,184,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55
(VllmWorker rank=2 pid=1701269) INFO 07-17 05:05:46 [cpu_worker.py:152] auto thread-binding list: 204,219,205,206,207,194,208,209,220,210,211,212,195,213,214,215,216,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78
(VllmWorker rank=3 pid=1701270) INFO 07-17 05:05:46 [cpu_worker.py:152] auto thread-binding list: 229,255,230,224,231,232,250,233,234,235,225,228,236,248,237,251,238,249,239,240,226,241,242,252,243,244,245,227,254,246,247,253

Only rank0 is as expected.
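The expected behaviour is roughly "rank i binds to the CPUs of NUMA node i". A minimal sketch of that idea (a hypothetical helper for illustration, not the actual worker code):

    from collections import defaultdict

    def cpus_for_rank(cpu_to_numa: dict[int, int], rank: int) -> list[int]:
        """Return the allowed CPU ids on the NUMA node matching the worker rank.

        cpu_to_numa maps a logical CPU id to its NUMA node id, e.g. parsed
        from `lscpu -e=CPU,NODE`.
        """
        node_to_cpus: dict[int, list[int]] = defaultdict(list)
        for cpu_id, node in cpu_to_numa.items():
            node_to_cpus[node].append(cpu_id)
        return sorted(node_to_cpus.get(rank, []))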

@bigPYJ1151
Contributor Author

Hi @ericcurtin, I notice your fix in #21115 relates to this PR.

I am refactoring the auto thread-binding procedure. I guess this PR can also resolve the bug you mentioned. Would you please help to check this? Thanks :)

@ericcurtin
Contributor

ericcurtin commented Jul 18, 2025

What I think is more important than a major refactoring of this code AGAIN is setting up CI to do some basic testing on a non-x86_64 architecture, in other words aarch64, since it's available on CI and it's the most popular non-x86_64 platform these days. Something like this:

vllm serve HuggingFaceTB/SmolLM2-135M-Instruct

and a quick query against it. One new issue that somehow got introduced in the last few commits: on aarch64 Linux VMs on MacBooks you have to do:

vllm serve --max-num-batched-tokens 8192 HuggingFaceTB/SmolLM2-135M-Instruct

to avoid this happening:

  Value error, max_num_batched_tokens (2048) is smaller than max_model_len (8192). This effectively limits the maximum sequence length to max_num_batched_tokens and makes vLLM reject longer sequences. Please increase max_num_batched_tokens or decrease max_model_len. [type=value_error, input_value=ArgsKwargs((), {'runner_t...ync_scheduling': False}), input_type=ArgsKwargs]

Why, I don't know. You didn't need this flag until recently for aarch64, and it's not needed when you do the exact same thing on x86_64. And you don't have to avoid errors in llama.cpp by adding little flags like this. It would be nice to achieve some stability here and not have to dance around issues.

@ericcurtin
Contributor

@bigPYJ1151 tested these changes manually, LGTM.

@bigPYJ1151
Contributor Author

@ericcurtin Thanks for checking :)

You are right. Setting up effective regular CI jobs for non-x86 CPUs can discover problems in time. If there are resources for it, the x86 CI scripts can be used directly.

@bigPYJ1151 changed the title from "[bugfix][WIP] Fix auto thread-binding when world_size > 1 in CPU backend and refactor code" to "[bugfix] Fix auto thread-binding when world_size > 1 in CPU backend and refactor code" on Jul 18, 2025
@bigPYJ1151
Contributor Author

Hi @DarkLight1337 @Isotr0py Would you please help review this PR? Thanks :)
I totally refactored the auto thread-binding procedure to make it more robust and easier to extend, and enabled auto thread-binding in the CPU CI, where it worked as expected.

logical_cpu_list = []
for cpu_list in core_to_cpus.values():
    cpu_list = sorted(cpu_list, key=lambda x: x.id)
    logical_cpu_list.extend(cpu_list[-cpu_num_per_core:])
Contributor


Hi @bigPYJ1151,
Here the code selects cpu_ids from the end, which would cause issues on PowerPC.

In SMT-8 mode the cpu_ids for core 0 would be
[0,1,2,3,4,5,6,7] and we select [4,5,6,7].

But if the user switches the SMT mode to 4, then the core 0 ids would be
[0,1,2,3,-,-,-,-] and the logic would fail.

Can we select cpu_ids from the start or add a platform check?

Contributor Author

@bigPYJ1151 commented on Jul 18, 2025


If using SMT-4, I think the logical CPU list of core 0 should be [0,1,2,3]?

Could you please share lscpu -e=CPU,CORE,NODE output on a SMT-4 machine?

Contributor

@Akashcodes732 commented on Jul 18, 2025


By default, Power systems are SMT-8 systems and lscpu -e=CPU,CORE,NODE looks like this:

Show lscpu -e=CPU,CORE,NODE output (SMT-8)
CPU CORE NODE
  0    0    0
  1    0    0
  2    0    0
  3    0    0
  4    0    0
  5    0    0
  6    0    0
  7    0    0
  8    1    0
  9    1    0
 10    1    0
 11    1    0
 12    1    0
 13    1    0
 14    1    0
 15    1    0
 16    2    0
 17    2    0
 18    2    0
 19    2    0
 20    2    0
 21    2    0
 22    2    0
 23    2    0
 24    3    0
 25    3    0
 26    3    0
 27    3    0
 28    3    0
 29    3    0
 30    3    0
 31    3    0
 32    4    0
 33    4    0
 34    4    0
 35    4    0
 36    4    0
 37    4    0
 38    4    0
 39    4    0
 40    5    0
 41    5    0
 42    5    0
 43    5    0
 44    5    0
 45    5    0
 46    5    0
 47    5    0
 48    6    0
 49    6    0
 50    6    0
 51    6    0
 52    6    0
 53    6    0
 54    6    0
 55    6    0
 56    7    0
 57    7    0
 58    7    0
 59    7    0
 60    7    0
 61    7    0
 62    7    0
 63    7    0
 64    8    0
 65    8    0
 66    8    0
 67    8    0
 68    8    0
 69    8    0
 70    8    0
 71    8    0
 72    9    0
 73    9    0
 74    9    0
 75    9    0
 76    9    0
 77    9    0
 78    9    0
 79    9    0
 80   10    0
 81   10    0
 82   10    0
 83   10    0
 84   10    0
 85   10    0
 86   10    0
 87   10    0
 88   11    0
 89   11    0
 90   11    0
 91   11    0
 92   11    0
 93   11    0
 94   11    0
 95   11    0
 96   12    1
 97   12    1
 98   12    1
 99   12    1
100   12    1
101   12    1
102   12    1
103   12    1
104   13    1
105   13    1
106   13    1
107   13    1
108   13    1
109   13    1
110   13    1
111   13    1
112   14    1
113   14    1
114   14    1
115   14    1
116   14    1
117   14    1
118   14    1
119   14    1
120   15    1
121   15    1
122   15    1
123   15    1
124   15    1
125   15    1
126   15    1
127   15    1
128   16    1
129   16    1
130   16    1
131   16    1
132   16    1
133   16    1
134   16    1
135   16    1
136   17    1
137   17    1
138   17    1
139   17    1
140   17    1
141   17    1
142   17    1
143   17    1
144   18    1
145   18    1
146   18    1
147   18    1
148   18    1
149   18    1
150   18    1
151   18    1
152   19    1
153   19    1
154   19    1
155   19    1
156   19    1
157   19    1
158   19    1
159   19    1
160   20    1
161   20    1
162   20    1
163   20    1
164   20    1
165   20    1
166   20    1
167   20    1
168   21    1
169   21    1
170   21    1
171   21    1
172   21    1
173   21    1
174   21    1
175   21    1
176   22    1
177   22    1
178   22    1
179   22    1
180   22    1
181   22    1
182   22    1
183   22    1
184   23    1
185   23    1
186   23    1
187   23    1
188   23    1
189   23    1
190   23    1
191   23    1
192   24    2
193   24    2
194   24    2
195   24    2
196   24    2
197   24    2
198   24    2
199   24    2
200   25    2
201   25    2
202   25    2
203   25    2
204   25    2
205   25    2
206   25    2
207   25    2
208   26    2
209   26    2
210   26    2
211   26    2
212   26    2
213   26    2
214   26    2
215   26    2
216   27    2
217   27    2
218   27    2
219   27    2
220   27    2
221   27    2
222   27    2
223   27    2
224   28    2
225   28    2
226   28    2
227   28    2
228   28    2
229   28    2
230   28    2
231   28    2
232   29    2
233   29    2
234   29    2
235   29    2
236   29    2
237   29    2
238   29    2
239   29    2
240   30    2
241   30    2
242   30    2
243   30    2
244   30    2
245   30    2
246   30    2
247   30    2
248   31    2
249   31    2
250   31    2
251   31    2
252   31    2
253   31    2
254   31    2
255   31    2
256   32    2
257   32    2
258   32    2
259   32    2
260   32    2
261   32    2
262   32    2
263   32    2
264   33    2
265   33    2
266   33    2
267   33    2
268   33    2
269   33    2
270   33    2
271   33    2
272   34    2
273   34    2
274   34    2
275   34    2
276   34    2
277   34    2
278   34    2
279   34    2
280   35    2
281   35    2
282   35    2
283   35    2
284   35    2
285   35    2
286   35    2
287   35    2
288   36    3
289   36    3
290   36    3
291   36    3
292   36    3
293   36    3
294   36    3
295   36    3
296   37    3
297   37    3
298   37    3
299   37    3
300   37    3
301   37    3
302   37    3
303   37    3
304   38    3
305   38    3
306   38    3
307   38    3
308   38    3
309   38    3
310   38    3
311   38    3
312   39    3
313   39    3
314   39    3
315   39    3
316   39    3
317   39    3
318   39    3
319   39    3
320   40    3
321   40    3
322   40    3
323   40    3
324   40    3
325   40    3
326   40    3
327   40    3
328   41    3
329   41    3
330   41    3
331   41    3
332   41    3
333   41    3
334   41    3
335   41    3
336   42    3
337   42    3
338   42    3
339   42    3
340   42    3
341   42    3
342   42    3
343   42    3
344   43    3
345   43    3
346   43    3
347   43    3
348   43    3
349   43    3
350   43    3
351   43    3
352   44    3
353   44    3
354   44    3
355   44    3
356   44    3
357   44    3
358   44    3
359   44    3
360   45    3
361   45    3
362   45    3
363   45    3
364   45    3
365   45    3
366   45    3
367   45    3
368   46    3
369   46    3
370   46    3
371   46    3
372   46    3
373   46    3
374   46    3
375   46    3
376   47    3
377   47    3
378   47    3
379   47    3
380   47    3
381   47    3
382   47    3
383   47    3

But when the user switches the mode to SMT-4, the output is

Show lscpu -e=CPU,CORE,NODE output (SMT-4)
CPU CORE NODE
  0    0    0
  1    0    0
  2    0    0
  3    0    0
  4    -    -
  5    -    -
  6    -    -
  7    -    -
  8    1    0
  9    1    0
 10    1    0
 11    1    0
 12    -    -
 13    -    -
 14    -    -
 15    -    -
 16    2    0
 17    2    0
 18    2    0
 19    2    0
 20    -    -
 21    -    -
 22    -    -
 23    -    -
 24    3    0
 25    3    0
 26    3    0
 27    3    0
 28    -    -
 29    -    -
 30    -    -
 31    -    -
 32    4    0
 33    4    0
 34    4    0
 35    4    0
 36    -    -
 37    -    -
 38    -    -
 39    -    -
 40    5    0
 41    5    0
 42    5    0
 43    5    0
 44    -    -
 45    -    -
 46    -    -
 47    -    -
 48    6    0
 49    6    0
 50    6    0
 51    6    0
 52    -    -
 53    -    -
 54    -    -
 55    -    -
 56    7    0
 57    7    0
 58    7    0
 59    7    0
 60    -    -
 61    -    -
 62    -    -
 63    -    -
 64    8    0
 65    8    0
 66    8    0
 67    8    0
 68    -    -
 69    -    -
 70    -    -
 71    -    -
 72    9    0
 73    9    0
 74    9    0
 75    9    0
 76    -    -
 77    -    -
 78    -    -
 79    -    -
 80   10    0
 81   10    0
 82   10    0
 83   10    0
 84    -    -
 85    -    -
 86    -    -
 87    -    -
 88   11    0
 89   11    0
 90   11    0
 91   11    0
 92    -    -
 93    -    -
 94    -    -
 95    -    -
 96   12    1
 97   12    1
 98   12    1
 99   12    1
100    -    -
101    -    -
102    -    -
103    -    -
104   13    1
105   13    1
106   13    1
107   13    1
108    -    -
109    -    -
110    -    -
111    -    -
112   14    1
113   14    1
114   14    1
115   14    1
116    -    -
117    -    -
118    -    -
119    -    -
120   15    1
121   15    1
122   15    1
123   15    1
124    -    -
125    -    -
126    -    -
127    -    -
128   16    1
129   16    1
130   16    1
131   16    1
132    -    -
133    -    -
134    -    -
135    -    -
136   17    1
137   17    1
138   17    1
139   17    1
140    -    -
141    -    -
142    -    -
143    -    -
144   18    1
145   18    1
146   18    1
147   18    1
148    -    -
149    -    -
150    -    -
151    -    -
152   19    1
153   19    1
154   19    1
155   19    1
156    -    -
157    -    -
158    -    -
159    -    -
160   20    1
161   20    1
162   20    1
163   20    1
164    -    -
165    -    -
166    -    -
167    -    -
168   21    1
169   21    1
170   21    1
171   21    1
172    -    -
173    -    -
174    -    -
175    -    -
176   22    1
177   22    1
178   22    1
179   22    1
180    -    -
181    -    -
182    -    -
183    -    -
184   23    1
185   23    1
186   23    1
187   23    1
188    -    -
189    -    -
190    -    -
191    -    -
192   24    2
193   24    2
194   24    2
195   24    2
196    -    -
197    -    -
198    -    -
199    -    -
200   25    2
201   25    2
202   25    2
203   25    2
204    -    -
205    -    -
206    -    -
207    -    -
208   26    2
209   26    2
210   26    2
211   26    2
212    -    -
213    -    -
214    -    -
215    -    -
216   27    2
217   27    2
218   27    2
219   27    2
220    -    -
221    -    -
222    -    -
223    -    -
224   28    2
225   28    2
226   28    2
227   28    2
228    -    -
229    -    -
230    -    -
231    -    -
232   29    2
233   29    2
234   29    2
235   29    2
236    -    -
237    -    -
238    -    -
239    -    -
240   30    2
241   30    2
242   30    2
243   30    2
244    -    -
245    -    -
246    -    -
247    -    -
248   31    2
249   31    2
250   31    2
251   31    2
252    -    -
253    -    -
254    -    -
255    -    -
256   32    2
257   32    2
258   32    2
259   32    2
260    -    -
261    -    -
262    -    -
263    -    -
264   33    2
265   33    2
266   33    2
267   33    2
268    -    -
269    -    -
270    -    -
271    -    -
272   34    2
273   34    2
274   34    2
275   34    2
276    -    -
277    -    -
278    -    -
279    -    -
280   35    2
281   35    2
282   35    2
283   35    2
284    -    -
285    -    -
286    -    -
287    -    -
288   36    3
289   36    3
290   36    3
291   36    3
292    -    -
293    -    -
294    -    -
295    -    -
296   37    3
297   37    3
298   37    3
299   37    3
300    -    -
301    -    -
302    -    -
303    -    -
304   38    3
305   38    3
306   38    3
307   38    3
308    -    -
309    -    -
310    -    -
311    -    -
312   39    3
313   39    3
314   39    3
315   39    3
316    -    -
317    -    -
318    -    -
319    -    -
320   40    3
321   40    3
322   40    3
323   40    3
324    -    -
325    -    -
326    -    -
327    -    -
328   41    3
329   41    3
330   41    3
331   41    3
332    -    -
333    -    -
334    -    -
335    -    -
336   42    3
337   42    3
338   42    3
339   42    3
340    -    -
341    -    -
342    -    -
343    -    -
344   43    3
345   43    3
346   43    3
347   43    3
348    -    -
349    -    -
350    -    -
351    -    -
352   44    3
353   44    3
354   44    3
355   44    3
356    -    -
357    -    -
358    -    -
359    -    -
360   45    3
361   45    3
362   45    3
363   45    3
364    -    -
365    -    -
366    -    -
367    -    -
368   46    3
369   46    3
370   46    3
371   46    3
372    -    -
373    -    -
374    -    -
375    -    -
376   47    3
377   47    3
378   47    3
379   47    3
380    -    -
381    -    -
382    -    -
383    -    -

Similarly if the user switches the mode to SMT-2, the output is (only pasting for Core 0)

~ lscpu -e=CPU,CORE,NODE
CPU CORE NODE
  0    0    0
  1    0    0
  2    -    -
  3    -    -
  4    -    -
  5    -    -
  6    -    -
  7    -    -

That's why, when the logic tries to select from the end, it fails, as it is not able to detect valid CPU ids.

My earlier logic handled cpu_id selection through this function:

    def select_threads_per_power_core(self,
                                      node_cpu_ids: list[int]) -> list[int]:
        return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]

which handled all the cases.

Contributor Author


Really special... Why are the disabled cores still visible to the system...
Anyway, I updated the core selection logic, please check it, thanks :)
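For reference, a sketch of the kind of selection that behaves the same on SMT-8, SMT-4 and SMT-2 (hypothetical names, assuming the offline siblings simply never appear in the per-core lists): keep only CPUs that report a valid core, then take threads from the start of each core instead of from the end.

    def pick_threads_per_core(core_to_cpus: dict[int, list[int]],
                              cpu_num_per_core: int) -> list[int]:
        """Pick the lowest-numbered logical CPUs of each physical core.

        On POWER in SMT-4 mode the high sibling ids of a core (e.g. 4-7 of
        core 0) are offline and report no core/node in lscpu, so they should
        never be present in core_to_cpus; taking CPUs from the start of each
        sorted sibling list then works for SMT-8, SMT-4 and SMT-2 alike.
        """
        selected: list[int] = []
        for cpu_ids in core_to_cpus.values():
            selected.extend(sorted(cpu_ids)[:cpu_num_per_core])
        return sorted(selected)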

Contributor


Thanks, this should work fine

Collaborator

@Isotr0py left a comment


Overall LGTM, just some nits. PTAL!

else:
    return obj_dict

assert platform.system() == "Linux"
Collaborator


I think we should check if the platform is Linux before calling auto thread-binding, otherwise platforms like macOS may fail at this assertion.
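For example, a guard along these lines (a hypothetical helper, the actual call site in the worker may differ) would skip auto-binding on non-Linux hosts instead of tripping the assertion:

    import platform

    def resolve_omp_cpuid(auto_bind_fn) -> str:
        """Return the thread-binding spec, skipping auto-binding off Linux.

        auto_bind_fn stands in for the worker's NUMA-based binding routine;
        on macOS and other platforms it is never called, so the Linux-only
        assertion is never reached and no binding is applied.
        """
        if platform.system() == "Linux":
            return auto_bind_fn()
        return "all"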

@Isotr0py added the `ready` label (ONLY add when PR is ready to merge/full CI is needed) on Jul 18, 2025
Signed-off-by: jiang1.li <[email protected]>
@vllm-bot merged commit e3a0e43 into vllm-project:main on Jul 19, 2025
61 of 64 checks passed
hj-mistral pushed a commit to hj-mistral/vllm that referenced this pull request Jul 19, 2025
@louie-tsai
Contributor

> Could you remind me what the original issue is? Was the issue introduced by the ppc64le support?
>
> @louie-tsai I think it is probably a pre-existing issue, as ppc64le implemented auto-binding with a different method.
>
> When we add -tp=4 and launch multiple workers, the current thread-binding lists are: [...]
>
> Only rank0 is as expected.

Looks strange, I didn't see that during my testing, but maybe I haven't tested all cases. Thanks.

LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
Labels: ci/build, documentation, ready, v1