|
22 | 22 | import torch
|
23 | 23 | import torch.utils.data
|
24 | 24 | import transformers
|
| 25 | +from accelerate import PartialState |
25 | 26 | from accelerate.utils import broadcast_object_list, gather, gather_object
|
26 | 27 | from accelerate.utils.other import is_compiled_module
|
27 | 28 | from datasets import Dataset, IterableDataset
|
@@ -342,18 +343,20 @@ def data_collator(features): # No data collation is needed in GRPO
|
342 | 343 |
|
343 | 344 | if self.accelerator.is_main_process:
|
344 | 345 | vllm_device = self.args.vllm_device
|
| 346 | + device_type = PartialState().default_device.type |
| 347 | + device_module = getattr(torch, device_type) |
345 | 348 | if vllm_device == "auto":
|
346 |
| - vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx |
| 349 | + vllm_device = f"{device_type}:{self.accelerator.num_processes}" # take the next GPU idx |
347 | 350 | # Check that the requested device is available
|
348 |
| - if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count(): |
| 351 | + if vllm_device.split(":")[0] == f"{device_type}" and int(vllm_device.split(":")[1]) >= device_module.device_count(): |
349 | 352 | raise ValueError(
|
350 | 353 | f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
|
351 | 354 | "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
|
352 | 355 | "value lower than the number of GPUs available on your machine—typically, reducing it by one "
|
353 |
| - f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`." |
| 356 | + f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`." |
354 | 357 | )
|
355 | 358 | # Check that the requested device is not also used for training
|
356 |
| - if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}: |
| 359 | + if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}: |
357 | 360 | warnings.warn(
|
358 | 361 | f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
|
359 | 362 | "behavior. It is recommended to use a dedicated device for vLLM."
|
|
0 commit comments