@@ -20,15 +20,19 @@ def __init__(self, num_gpus=0):
         super(Net, self).__init__()
         print(f"Using {num_gpus} GPUs to train")
         self.num_gpus = num_gpus
-        device = torch.device(
-            "cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu")
+        if torch.accelerator.is_available() and self.num_gpus > 0:
+            acc = torch.accelerator.current_accelerator()
+            device = torch.device(f'{acc}:0')
+        else:
+            device = torch.device("cpu")
         print(f"Putting first 2 convs on {str(device)}")
-        # Put conv layers on the first cuda device
+        # Put conv layers on the first accelerator device
         self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
         self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)
-        # Put rest of the network on the 2nd cuda device, if there is one
-        if "cuda" in str(device) and num_gpus > 1:
-            device = torch.device("cuda:1")
+        # Put rest of the network on the 2nd accelerator device, if there is one
+        if torch.accelerator.is_available() and self.num_gpus > 1:
+            acc = torch.accelerator.current_accelerator()
+            device = torch.device(f'{acc}:1')
 
         print(f"Putting rest of layers on {str(device)}")
         self.dropout1 = nn.Dropout2d(0.25).to(device)
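
Note that the same accelerator-or-CPU selection recurs throughout this commit (twice in Net.__init__ above, and again in ParameterServer.__init__ and get_accuracy below). A minimal sketch of how it could be factored into one helper; pick_device is a hypothetical name, not part of this change:

import torch

def pick_device(num_gpus, index=0):
    # Use the current accelerator backend (e.g. cuda, xpu, mps) when one is
    # available and enough devices were requested; otherwise fall back to CPU.
    if torch.accelerator.is_available() and num_gpus > index:
        acc = torch.accelerator.current_accelerator()
        return torch.device(f'{acc}:{index}')
    return torch.device("cpu")

With this helper, the two blocks above reduce to pick_device(self.num_gpus, 0) and pick_device(self.num_gpus, 1).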
@@ -72,21 +76,22 @@ def call_method(method, rref, *args, **kwargs):
 # <foo_instance>.bar(arg1, arg2) on the remote node and getting the result
 # back.
 
-
 def remote_method(method, rref, *args, **kwargs):
     args = [method, rref] + list(args)
     return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)
 
-
 # --------- Parameter Server --------------------
 class ParameterServer(nn.Module):
     def __init__(self, num_gpus=0):
         super().__init__()
         model = Net(num_gpus=num_gpus)
         self.model = model
-        self.input_device = torch.device(
-            "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")
-
+        if torch.accelerator.is_available() and num_gpus > 0:
+            acc = torch.accelerator.current_accelerator()
+            self.input_device = torch.device(f'{acc}:0')
+        else:
+            self.input_device = torch.device("cpu")
+
     def forward(self, inp):
         inp = inp.to(self.input_device)
         out = self.model(inp)
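
Because the first conv layers live on the first accelerator device, ParameterServer.forward moves each incoming tensor to self.input_device before calling the model. A CPU-only smoke test of that pattern, assuming the rest of this file (Net.forward and the torch imports) is present; the input shape is an assumption based on the 1-channel conv1:

ps = ParameterServer(num_gpus=0)    # no accelerator requested: everything stays on CPU
batch = torch.randn(4, 1, 28, 28)   # hypothetical MNIST-sized batch
out = ps(batch)                     # forward() moves the batch to input_device first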
@@ -113,11 +118,9 @@ def get_param_rrefs(self):
         param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
         return param_rrefs
 
-
 param_server = None
 global_lock = Lock()
 
-
 def get_parameter_server(num_gpus=0):
     global param_server
     # Ensure that we get only one handle to the ParameterServer.
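
The hunk cuts off just as get_parameter_server begins, but the module-level param_server and global_lock it relies on are shown above: the lock serializes construction so every trainer RPC gets a handle to the same ParameterServer. A hedged sketch of that pattern under hypothetical names (the real body lies outside this diff):

from threading import Lock

_instance, _lock = None, Lock()

def get_instance(num_gpus=0):
    global _instance
    # Construct at most once, even if several RPC threads race here.
    with _lock:
        if _instance is None:
            _instance = ParameterServer(num_gpus=num_gpus)
        return _instance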
@@ -197,8 +200,11 @@ def get_accuracy(test_loader, model):
     model.eval()
     correct_sum = 0
     # Use GPU to evaluate if possible
-    device = torch.device("cuda:0" if model.num_gpus > 0
-                          and torch.cuda.is_available() else "cpu")
+    if torch.accelerator.is_available() and model.num_gpus > 0:
+        acc = torch.accelerator.current_accelerator()
+        device = torch.device(f'{acc}:0')
+    else:
+        device = torch.device("cpu")
     with torch.no_grad():
         for i, (data, target) in enumerate(test_loader):
             out = model(data)
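
This hunk ends just as the evaluation loop begins; only the device selection changed in the commit. For context, a hedged sketch of how such a loop typically uses the chosen device (an illustration, not the file's remaining code):

correct_sum = 0
with torch.no_grad():
    for data, target in test_loader:
        out = model(data)
        pred = out.argmax(dim=1)
        # Compare predictions and labels on the selected device.
        correct_sum += (pred.to(device) == target.to(device)).sum().item()
print(f"Accuracy: {correct_sum / len(test_loader.dataset):.4f}")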