@@ -22,6 +22,7 @@
 import torch
 import torch.utils.data
 import transformers
+from accelerate import PartialState
 from accelerate.utils import broadcast_object_list, gather, gather_object
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
@@ -342,18 +343,21 @@ def data_collator(features):  # No data collation is needed in GRPO

             if self.accelerator.is_main_process:
                 vllm_device = self.args.vllm_device
+                device_type = PartialState().default_device.type
+                device_module = getattr(torch, device_type)
                 if vllm_device == "auto":
-                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+                    vllm_device = f"{device_type}:{self.accelerator.num_processes}"  # take the next GPU idx
+                    self.args.vllm_device = vllm_device
                 # Check that the requested device is available
-                if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
+                if vllm_device.split(":")[0] == device_type and int(vllm_device.split(":")[1]) >= device_module.device_count():
                     raise ValueError(
                         f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                         "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                         "value lower than the number of GPUs available on your machine—typically, reducing it by one "
-                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+                        f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
                     )
                 # Check that the requested device is not also used for training
-                if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
+                if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
                     warnings.warn(
                         f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
                         "behavior. It is recommended to use a dedicated device for vLLM."
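The hunk above swaps hard-coded `"cuda"` strings for whatever accelerator type `accelerate` detects at runtime. A minimal standalone sketch of that pattern, assuming only that `torch` and `accelerate` are installed (`num_training_processes` is a stand-in for `self.accelerator.num_processes`, not part of the patch):

```python
import torch
from accelerate import PartialState

num_training_processes = 2  # stand-in for self.accelerator.num_processes

device_type = PartialState().default_device.type    # e.g. "cuda", "npu", "xpu", or "cpu"
device_module = getattr(torch, device_type, None)   # torch.cuda, torch.npu, ... (may be absent)

# Mirror the "auto" branch: hand the next device index after the training processes to vLLM.
vllm_device = f"{device_type}:{num_training_processes}"

if device_module is not None and hasattr(device_module, "device_count"):
    available = device_module.device_count()
    if num_training_processes >= available:
        print(
            f"{vllm_device} is not available ({available} {device_type} device(s) found); "
            f"use --num_processes {max(available - 1, 1)} to leave one device for vLLM."
        )
```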
@@ -470,6 +474,11 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
                     else:
                         state_dict = unwrapped_model.state_dict()
                 if self.accelerator.is_main_process:
+                    if PartialState().default_device.type == "npu":
+                        # For Ascend NPUs, torch.Tensor.copy_ does not support cross-device tensor copy,
+                        # so stage the weights on the CPU before moving them to the vLLM device.
+                        for k, v in state_dict.items():
+                            if isinstance(v, torch.Tensor):
+                                state_dict[k] = v.to("cpu").to(self.args.vllm_device)
                     llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                     llm_model.load_weights(state_dict.items())
                 self._last_loaded_step = self.state.global_step
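The added `npu` branch routes every weight through host memory, because a direct device-to-device copy is not supported on Ascend NPUs. A self-contained sketch of that staging step (the helper name and the sample tensors are illustrative, not part of the patch):

```python
import torch


def stage_via_cpu(state_dict: dict, target_device: str) -> dict:
    """Move tensors to `target_device` with an intermediate hop through the CPU."""
    return {
        name: value.to("cpu").to(target_device) if isinstance(value, torch.Tensor) else value
        for name, value in state_dict.items()
    }


# Usage: on Ascend hardware `target_device` would be something like "npu:7";
# "cpu" is used here so the sketch runs anywhere.
weights = {"linear.weight": torch.randn(4, 4), "global_step": 3}
staged = stage_via_cpu(weights, "cpu")
```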