
Commit a84f91c

Enhance GPU verification and cleanup in DDP RPC example

- Added a function to verify minimum GPU count before execution.
- Updated HybridModel initialization to use rank instead of device.
- Ensured proper cleanup of the process group to avoid resource leaks.
- Added exit message if insufficient GPUs are detected.

Signed-off-by: jafraustro <[email protected]>

1 parent a790549 commit a84f91c

File tree

1 file changed: +19 -5 lines changed


distributed/rpc/ddp_rpc/main.py

Lines changed: 19 additions & 5 deletions
@@ -15,6 +15,11 @@
 NUM_EMBEDDINGS = 100
 EMBEDDING_DIM = 16
 
+def verify_min_gpu_count(min_gpus: int = 2) -> bool:
+    """ verification that we have at least 2 gpus to run dist examples """
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus
 
 
 class HybridModel(torch.nn.Module):
     r"""
@@ -24,16 +29,15 @@ class HybridModel(torch.nn.Module):
     This remote model can get a Remote Reference to the embedding table on the parameter server.
     """
 
-    def __init__(self, remote_emb_module, device):
+    def __init__(self, remote_emb_module, rank):
         super(HybridModel, self).__init__()
         self.remote_emb_module = remote_emb_module
-        self.fc = DDP(torch.nn.Linear(16, 8).to(device), device_ids=[device])
-        self.device = device
+        self.fc = DDP(torch.nn.Linear(16, 8).to(rank))
+        self.rank = rank
 
     def forward(self, indices, offsets):
-        device = torch.accelerator.current_accelerator()
         emb_lookup = self.remote_emb_module.forward(indices, offsets)
-        return self.fc(emb_lookup.to(self.device))
+        return self.fc(emb_lookup.to(self.rank))
 
 
 def _run_trainer(remote_emb_module, rank):
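
Aside (not part of the commit): with the new signature the trainer passes its integer rank, the wrapped layer is moved with Module.to(rank), and DDP infers the device from the module's parameters instead of an explicit device_ids list. A small sketch of that integer-ordinal behaviour, assuming a CUDA build with at least one device:

    import torch

    rank = 0  # hypothetical local rank of a trainer process
    layer = torch.nn.Linear(16, 8)
    if torch.cuda.is_available():
        # An integer device argument is treated as a device ordinal, here cuda:0.
        layer = layer.to(rank)
    print(next(layer.parameters()).device)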
@@ -152,6 +156,7 @@ def run_worker(rank, world_size):
         else:
             device = torch.device("cpu")
         backend = torch.distributed.get_default_backend_for_device(device)
+        torch.accelerator.device_index(rank)
         # Initialize process group for Distributed DataParallel on trainers.
         dist.init_process_group(
             backend=backend, rank=rank, world_size=2, init_method="tcp://localhost:29500"
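
Aside (not part of the commit): the backend string passed to init_process_group is derived from the device rather than hard-coded. A minimal sketch of that selection, assuming a PyTorch build that provides get_default_backend_for_device:

    import torch
    import torch.distributed as dist

    # Pick the accelerator if one is present, otherwise fall back to CPU.
    device = (
        torch.accelerator.current_accelerator()
        if torch.accelerator.is_available()
        else torch.device("cpu")
    )
    # Typically "nccl" for CUDA devices and "gloo" for CPU.
    backend = dist.get_default_backend_for_device(device)
    print(device, backend)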
@@ -179,9 +184,18 @@ def run_worker(rank, world_size):
 
     # block until all rpcs finish
     rpc.shutdown()
+
+    # Clean up process group for trainers to avoid resource leaks
+    if rank <= 1:
+        dist.destroy_process_group()
 
 
 if __name__ == "__main__":
     # 2 trainers, 1 parameter server, 1 master.
     world_size = 4
+    _min_gpu_count = 2
+    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+        print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+        exit()
     mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
+    print("Distributed RPC example completed successfully.")
