@@ -233,6 +233,8 @@ def device_sync(device="cpu"):
233
233
torch .cuda .synchronize (device )
234
234
elif "xpu" in device :
235
235
torch .xpu .synchronize (device )
236
+ elif "npu" in device :
237
+ torch .npu .synchronize (device )
236
238
elif ("cpu" in device ) or ("mps" in device ):
237
239
pass
238
240
else :
@@ -275,33 +277,36 @@ def is_mps_available() -> bool:
275
277
# MPS, is that you?
276
278
return True
277
279
280
+ def select_device (device ) -> str :
281
+ if torch .cuda .is_available ():
282
+ return "cuda"
283
+ elif is_mps_available ():
284
+ return "mps"
285
+ elif hasattr (torch , "npu" ) and torch .npu .is_available ():
286
+ return "npu"
287
+ elif torch .xpu .is_available ():
288
+ return "xpu"
289
+ else :
290
+ return "cpu"
291
+
278
292
279
293
def get_device_str (device ) -> str :
280
294
if isinstance (device , str ) and device == "fast" :
281
- device = (
282
- "cuda"
283
- if torch .cuda .is_available ()
284
- else "mps" if is_mps_available ()
285
- else "xpu" if torch .xpu .is_available () else "cpu"
286
- )
295
+ device = select_device (device )
287
296
return device
288
297
else :
289
298
return str (device )
290
299
291
300
292
301
def get_device (device ) -> str :
293
302
if isinstance (device , str ) and device == "fast" :
294
- device = (
295
- "cuda"
296
- if torch .cuda .is_available ()
297
- else "mps" if is_mps_available ()
298
- else "xpu" if torch .xpu .is_available () else "cpu"
299
- )
303
+ device = select_device (device )
300
304
return torch .device (device )
301
305
302
306
303
307
def is_cpu_device (device ) -> bool :
304
308
return device == "" or str (device ) == "cpu"
305
309
306
- def is_cuda_or_cpu_or_xpu_device (device ) -> bool :
307
- return is_cpu_device (device ) or ("cuda" in str (device )) or ("xpu" in str (device ))
310
+ def is_supported_device (device ) -> bool :
311
+ device_str = str (device )
312
+ return is_cpu_device (device ) or any (dev in device_str for dev in ('cuda' , 'xpu' , 'npu' ))
0 commit comments