Commit ff4b307

jafraustro committed
Enhance GPU verification and cleanup in DDP RPC example
- Added a function to verify minimum GPU count before execution.
- Updated HybridModel initialization to use rank instead of device.
- Ensured proper cleanup of the process group to avoid resource leaks.
- Added exit message if insufficient GPUs are detected.

Signed-off-by: jafraustro <[email protected]>
1 parent a790549 commit ff4b307
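
For quick reference, the GPU-count check described in the first bullet is a thin wrapper over the torch.accelerator queries, used to gate the spawn of the four workers. A minimal standalone sketch (the function body and message come from the committed diff below; the condensed `__main__` gate here is for illustration only):

import torch

def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    """ verification that we have at least 2 gpus to run dist examples """
    has_gpu = torch.accelerator.is_available()
    gpu_count = torch.accelerator.device_count()
    return has_gpu and gpu_count >= min_gpus

if __name__ == "__main__":
    # 2 trainers, 1 parameter server, 1 master -> the trainers need 2 accelerators.
    if not verify_min_gpu_count(min_gpus=2):
        print("Unable to locate sufficient 2 gpus to run this example. Exiting.")
        exit()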

File tree

1 file changed (+20 −5 lines)

distributed/rpc/ddp_rpc/main.py

Lines changed: 20 additions & 5 deletions
@@ -15,6 +15,11 @@
 NUM_EMBEDDINGS = 100
 EMBEDDING_DIM = 16
 
+def verify_min_gpu_count(min_gpus: int = 2) -> bool:
+    """ verification that we have at least 2 gpus to run dist examples """
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus
 
 class HybridModel(torch.nn.Module):
     r"""
@@ -24,16 +29,16 @@ class HybridModel(torch.nn.Module):
     This remote model can get a Remote Reference to the embedding table on the parameter server.
     """
 
-    def __init__(self, remote_emb_module, device):
+    def __init__(self, remote_emb_module, rank):
         super(HybridModel, self).__init__()
         self.remote_emb_module = remote_emb_module
-        self.fc = DDP(torch.nn.Linear(16, 8).to(device), device_ids=[device])
-        self.device = device
+        self.fc = DDP(torch.nn.Linear(16, 8).to(rank))
+        self.rank = rank
 
     def forward(self, indices, offsets):
-        device = torch.accelerator.current_accelerator()
+        # device = torch.cuda.current_device()
         emb_lookup = self.remote_emb_module.forward(indices, offsets)
-        return self.fc(emb_lookup.to(self.device))
+        return self.fc(emb_lookup.to(self.rank))
 
 
 def _run_trainer(remote_emb_module, rank):
@@ -152,6 +157,7 @@ def run_worker(rank, world_size):
     else:
         device = torch.device("cpu")
     backend = torch.distributed.get_default_backend_for_device(device)
+    torch.accelerator.device_index(rank)
     # Initialize process group for Distributed DataParallel on trainers.
     dist.init_process_group(
         backend=backend, rank=rank, world_size=2, init_method="tcp://localhost:29500"
@@ -179,9 +185,18 @@ def run_worker(rank, world_size):
 
     # block until all rpcs finish
     rpc.shutdown()
+
+    # Clean up process group for trainers to avoid resource leaks
+    if rank <= 1:
+        dist.destroy_process_group()
 
 
 if __name__ == "__main__":
     # 2 trainers, 1 parameter server, 1 master.
     world_size = 4
+    _min_gpu_count = 2
+    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+        print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+        exit()
     mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
+    print("Distributed RPC example completed successfully.")

0 commit comments
