File tree 4 files changed +16
-5
lines changed
4 files changed +16
-5
lines changed Original file line number Diff line number Diff line change @@ -78,7 +78,10 @@ def __init__(
78
78
self ._backward_sync_control = _DDPBackwardSyncControl ()
79
79
self ._ddp_kwargs = kwargs
80
80
81
- self .device_type = self .root_device .type
81
+ if isinstance (self .accelerator , Accelerator ):
82
+ self .device_type = self .accelerator .get_device_type ()
83
+ else :
84
+ self .device_type = "cuda"
82
85
self .torch_lib = getattr (torch , self .device_type )
83
86
84
87
@property
Original file line number Diff line number Diff line change @@ -299,7 +299,10 @@ def __init__(
299
299
300
300
self ._deepspeed_engine : Optional [DeepSpeedEngine ] = None
301
301
302
- self .device_type = self .root_device .type
302
+ if isinstance (self .accelerator , Accelerator ):
303
+ self .device_type = self .accelerator .get_device_type ()
304
+ else :
305
+ self .device_type = "cuda"
303
306
self .torch_lib = getattr (torch , self .device_type )
304
307
305
308
@property
Original file line number Diff line number Diff line change @@ -103,8 +103,10 @@ def __init__(
103
103
self ._timeout : Optional [timedelta ] = timeout
104
104
self ._start_method = start_method
105
105
106
- self .device_type = self .root_device .type
107
- self .torch_lib = getattr (torch , self .device_type )
106
+ try :
107
+ self .device_type = self .accelerator .get_device_type ()
108
+ except Exception :
109
+ self .device_type = "cuda"
108
110
109
111
@property
110
112
def is_distributed (self ) -> bool : # pragma: no-cover
Original file line number Diff line number Diff line change @@ -319,7 +319,10 @@ def __init__(
319
319
self .hysteresis = hysteresis
320
320
self .min_loss_scale = min_loss_scale
321
321
322
- self .device_type = self .root_device .type
322
+ try :
323
+ self .device_type = self .accelerator .get_device_type ()
324
+ except Exception :
325
+ self .device_type = "cuda"
323
326
self .torch_lib = getattr (torch , self .device_type )
324
327
325
328
@override
You can’t perform that action at this time.
0 commit comments