Commit 95b6c76

add grpo support for third-party devices

1 parent 55e680e commit 95b6c76

1 file changed: trl/trainer/grpo_trainer.py (+7, -4 lines)

@@ -22,6 +22,7 @@
 import torch
 import torch.utils.data
 import transformers
+from accelerate import PartialState
 from accelerate.utils import broadcast_object_list, gather, gather_object
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
@@ -342,18 +343,20 @@ def data_collator(features):  # No data collation is needed in GRPO
 
             if self.accelerator.is_main_process:
                 vllm_device = self.args.vllm_device
+                device_type = PartialState().default_device.type
+                device_module = getattr(torch, device_type)
                 if vllm_device == "auto":
-                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+                    vllm_device = f"{device_type}:{self.accelerator.num_processes}"  # take the next GPU idx
                 # Check that the requested device is available
-                if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
+                if vllm_device.split(":")[0] == f"{device_type}" and int(vllm_device.split(":")[1]) >= device_module.device_count():
                     raise ValueError(
                         f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                         "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                         "value lower than the number of GPUs available on your machine—typically, reducing it by one "
-                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+                        f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
                     )
                 # Check that the requested device is not also used for training
-                if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
+                if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
                     warnings.warn(
                         f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
                         "behavior. It is recommended to use a dedicated device for vLLM."
