
Commit 9354458

Add accelerator API to RPC distributed examples:
- ddp_rpc
- parameter_server
- rnn

Signed-off-by: jafraustro <[email protected]>

1 parent de7db4c

File tree

4 files changed: +39 additions, -24 deletions


distributed/rpc/ddp_rpc/main.py

Lines changed: 11 additions & 4 deletions

@@ -27,12 +27,13 @@ class HybridModel(torch.nn.Module):
     def __init__(self, remote_emb_module, device):
         super(HybridModel, self).__init__()
         self.remote_emb_module = remote_emb_module
-        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
+        self.fc = DDP(torch.nn.Linear(16, 8).to(device), device_ids=[device])
         self.device = device
 
     def forward(self, indices, offsets):
+        device = torch.accelerator.current_accelerator()
         emb_lookup = self.remote_emb_module.forward(indices, offsets)
-        return self.fc(emb_lookup.cuda(self.device))
+        return self.fc(emb_lookup.to(self.device))
 
 
 def _run_trainer(remote_emb_module, rank):

@@ -83,7 +84,7 @@ def get_next_batch(rank):
             batch_size += 1
 
         offsets_tensor = torch.LongTensor(offsets)
-        target = torch.LongTensor(batch_size).random_(8).cuda(rank)
+        target = torch.LongTensor(batch_size).random_(8).to(rank)
         yield indices, offsets_tensor, target
 
     # Train for 100 epochs

@@ -145,9 +146,15 @@ def run_worker(rank, world_size):
         for fut in futs:
             fut.wait()
     elif rank <= 1:
+        if torch.accelerator.is_available():
+            acc = torch.accelerator.current_accelerator()
+            device = torch.device(acc)
+        else:
+            device = torch.device("cpu")
+        backend = torch.distributed.get_default_backend_for_device(device)
         # Initialize process group for Distributed DataParallel on trainers.
         dist.init_process_group(
-            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
+            backend=backend, rank=rank, world_size=2, init_method="tcp://localhost:29500"
        )
 
        # Initialize RPC.
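Note (not part of the commit): the last hunk replaces the hard-coded "gloo" backend with one derived from the active accelerator. Below is a minimal, self-contained sketch of that pattern, assuming the same init_method as the example; the function name init_trainer_process_group and the world_size default are illustrative only.

import torch
import torch.distributed as dist


def init_trainer_process_group(rank, world_size=2):
    # Prefer the active accelerator (e.g. CUDA or XPU) and fall back to CPU.
    if torch.accelerator.is_available():
        device = torch.accelerator.current_accelerator()
    else:
        device = torch.device("cpu")
    # Let torch.distributed pick the matching backend, e.g. "nccl" for CUDA
    # devices or "gloo" for CPU, instead of hard-coding one.
    backend = dist.get_default_backend_for_device(device)
    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
        init_method="tcp://localhost:29500",
    )
    return device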
requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-torch>=1.6.0
+torch>=2.7.1

distributed/rpc/parameter_server/rpc_parameter_server.py

Lines changed: 21 additions & 15 deletions

@@ -20,15 +20,19 @@ def __init__(self, num_gpus=0):
         super(Net, self).__init__()
         print(f"Using {num_gpus} GPUs to train")
         self.num_gpus = num_gpus
-        device = torch.device(
-            "cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu")
+        if torch.accelerator.is_available() and self.num_gpus > 0:
+            acc = torch.accelerator.current_accelerator()
+            device = torch.device(f'{acc}:0')
+        else:
+            device = torch.device("cpu")
         print(f"Putting first 2 convs on {str(device)}")
-        # Put conv layers on the first cuda device
+        # Put conv layers on the first accelerator device
         self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
         self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)
-        # Put rest of the network on the 2nd cuda device, if there is one
-        if "cuda" in str(device) and num_gpus > 1:
-            device = torch.device("cuda:1")
+        # Put rest of the network on the 2nd accelerator device, if there is one
+        if torch.accelerator.is_available() and self.num_gpus > 0:
+            acc = torch.accelerator.current_accelerator()
+            device = torch.device(f'{acc}:1')
 
         print(f"Putting rest of layers on {str(device)}")
         self.dropout1 = nn.Dropout2d(0.25).to(device)

@@ -72,21 +76,22 @@ def call_method(method, rref, *args, **kwargs):
 # <foo_instance>.bar(arg1, arg2) on the remote node and getting the result
 # back.
 
-
 def remote_method(method, rref, *args, **kwargs):
     args = [method, rref] + list(args)
     return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)
 
-
 # --------- Parameter Server --------------------
 class ParameterServer(nn.Module):
     def __init__(self, num_gpus=0):
         super().__init__()
         model = Net(num_gpus=num_gpus)
         self.model = model
-        self.input_device = torch.device(
-            "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")
-
+        if torch.accelerator.is_available() and num_gpus > 0:
+            acc = torch.accelerator.current_accelerator()
+            self.input_device = torch.device(f'{acc}:0')
+        else:
+            self.input_device = torch.device("cpu")
+
     def forward(self, inp):
         inp = inp.to(self.input_device)
         out = self.model(inp)

@@ -113,11 +118,9 @@ def get_param_rrefs(self):
         param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
         return param_rrefs
 
-
 param_server = None
 global_lock = Lock()
 
-
 def get_parameter_server(num_gpus=0):
     global param_server
     # Ensure that we get only one handle to the ParameterServer.

@@ -197,8 +200,11 @@ def get_accuracy(test_loader, model):
     model.eval()
     correct_sum = 0
     # Use GPU to evaluate if possible
-    device = torch.device("cuda:0" if model.num_gpus > 0
-                          and torch.cuda.is_available() else "cpu")
+    if torch.accelerator.is_available() and model.num_gpus > 0:
+        acc = torch.accelerator.current_accelerator()
+        device = torch.device(f'{acc}:0')
+    else:
+        device = torch.device("cpu")
     with torch.no_grad():
         for i, (data, target) in enumerate(test_loader):
             out = model(data)
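Note (not part of the commit): the hunks above repeat the same device-selection logic in Net.__init__, ParameterServer.__init__, and get_accuracy. The sketch below consolidates that pattern for illustration; the helper name resolve_device is hypothetical and does not appear in the example.

import torch
import torch.nn as nn


def resolve_device(num_gpus, index=0):
    # Use the current accelerator when one exists and GPUs were requested,
    # otherwise fall back to CPU.
    if torch.accelerator.is_available() and num_gpus > 0:
        acc = torch.accelerator.current_accelerator()  # a torch.device, e.g. device(type='cuda')
        return torch.device(f"{acc.type}:{index}")
    return torch.device("cpu")


# Example placement in the spirit of Net.__init__ above.
device0 = resolve_device(num_gpus=1, index=0)
conv1 = nn.Conv2d(1, 32, 3, 1).to(device0)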

distributed/rpc/rnn/rnn.py

Lines changed: 6 additions & 4 deletions

@@ -43,13 +43,15 @@ def __init__(self, ntoken, ninp, dropout):
         super(EmbeddingTable, self).__init__()
         self.drop = nn.Dropout(dropout)
         self.encoder = nn.Embedding(ntoken, ninp)
-        if torch.cuda.is_available():
-            self.encoder = self.encoder.cuda()
+        if torch.accelerator.is_available():
+            device = torch.accelerator.current_accelerator()
+            self.encoder = self.encoder.to(device)
         nn.init.uniform_(self.encoder.weight, -0.1, 0.1)
 
     def forward(self, input):
-        if torch.cuda.is_available():
-            input = input.cuda()
+        if torch.accelerator.is_available():
+            device = torch.accelerator.current_accelerator()
+            input = input.to(device)
         return self.drop(self.encoder(input)).cpu()