Commit d91896a

add grpo support for third-party devices
Parent: 55e680e

1 file changed: trl/trainer/grpo_trainer.py (+13 -4 lines)
@@ -22,6 +22,7 @@
 import torch
 import torch.utils.data
 import transformers
+from accelerate import PartialState
 from accelerate.utils import broadcast_object_list, gather, gather_object
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
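
For context, a minimal sketch (not part of the diff) of what the new import provides: accelerate's PartialState exposes the default device of whichever backend is active, which is what lets the trainer stop hard-coding "cuda". Device names in the comments are illustrative.

import torch
from accelerate import PartialState

device_type = PartialState().default_device.type  # e.g. "cuda", "npu", or "cpu"
device_module = getattr(torch, device_type)       # torch.cuda, torch.npu, ... expose a matching device API
if hasattr(device_module, "device_count"):
    print(f"{device_type}: {device_module.device_count()} device(s) visible")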
@@ -342,18 +343,21 @@ def data_collator(features):  # No data collation is needed in GRPO
 
             if self.accelerator.is_main_process:
                 vllm_device = self.args.vllm_device
+                device_type = PartialState().default_device.type
+                device_module = getattr(torch, device_type)
                 if vllm_device == "auto":
-                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+                    vllm_device = f"{device_type}:{self.accelerator.num_processes}"  # take the next GPU idx
+                    self.args.vllm_device = vllm_device
                 # Check that the requested device is available
-                if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
+                if vllm_device.split(":")[0] == device_type and int(vllm_device.split(":")[1]) >= device_module.device_count():
                     raise ValueError(
                         f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                         "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                         "value lower than the number of GPUs available on your machine—typically, reducing it by one "
-                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+                        f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
                     )
                 # Check that the requested device is not also used for training
-                if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
+                if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
                     warnings.warn(
                         f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
                         "behavior. It is recommended to use a dedicated device for vLLM."
@@ -470,6 +474,11 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 else:
                     state_dict = unwrapped_model.state_dict()
             if self.accelerator.is_main_process:
+                if PartialState().default_device.type == "npu":
+                    # For Ascend NPUs, torch.Tensor.copy_ does not support cross-device tensor copy
+                    for k, v in state_dict.items():
+                        if isinstance(v, torch.Tensor):
+                            state_dict[k] = v.to("cpu").to(self.args.vllm_device)
                 llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                 llm_model.load_weights(state_dict.items())
             self._last_loaded_step = self.state.global_step
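
This last hunk works around a backend limitation: on Ascend NPUs, torch.Tensor.copy_ cannot copy directly between devices, so each weight is staged through host memory before landing on the dedicated vLLM device. A minimal sketch of that staging pattern (hypothetical stage_through_cpu helper; vllm_device is assumed to be a string such as "npu:1"):

import torch

def stage_through_cpu(state_dict: dict[str, torch.Tensor], vllm_device: str) -> dict[str, torch.Tensor]:
    # Route each weight via host memory instead of a direct device-to-device copy.
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            state_dict[name] = value.to("cpu").to(vllm_device)
    return state_dict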
