File tree 1 file changed +5
-2
lines changed
1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change 22
22
import torch
23
23
import torch.utils.data
24
24
import transformers
25
+ from accelerate import PartialState
25
26
from accelerate.utils import broadcast_object_list, gather, gather_object
26
27
from accelerate.utils.other import is_compiled_module
27
28
from datasets import Dataset, IterableDataset
@@ -342,10 +343,12 @@ def data_collator(features): # No data collation is needed in GRPO
342
343
343
344
if self.accelerator.is_main_process:
344
345
vllm_device = self.args.vllm_device
346
+ device_type = PartialState().default_device.type
347
+ device_module = getattr(torch, device_type)
345
348
if vllm_device == "auto":
346
- vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
349
+ vllm_device = f"{device_type}:{self.accelerator.num_processes}"  # take the next GPU idx
347
350
# Check that the requested device is available
348
- if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
351
+ if vllm_device.split(":")[0] == f"{device_type}" and int(vllm_device.split(":")[1]) >= device_module.device_count():
349
352
raise ValueError(
350
353
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
351
354
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
You can’t perform that action at this time.
0 commit comments