Add accelerator API to RPC distributed examples: ddp_rpc, parameter_server, rnn #1371

Open: wants to merge 4 commits into base branch main
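All three examples are migrated from CUDA-specific calls (torch.cuda.is_available(), .cuda(rank)) to the device-agnostic torch.accelerator API and plain .to(...) moves, so the same code runs on any supported accelerator and falls back to CPU. The snippet below is a minimal, illustrative sketch of the selection pattern the diffs repeat; it assumes PyTorch >= 2.7.1 (as pinned in the updated requirements) and is not itself part of the change set.

    import torch
    import torch.distributed as dist

    # Prefer whatever accelerator this build exposes (cuda, xpu, mps, ...); fall back to CPU.
    if torch.accelerator.is_available():
        device = torch.device(torch.accelerator.current_accelerator())
    else:
        device = torch.device("cpu")

    # Ask torch.distributed for the collective backend that matches the device,
    # instead of hard-coding "gloo" or "nccl".
    backend = dist.get_default_backend_for_device(device)
    print(device, backend)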
33 changes: 27 additions & 6 deletions distributed/rpc/ddp_rpc/main.py
@@ -15,6 +15,11 @@
NUM_EMBEDDINGS = 100
EMBEDDING_DIM = 16

def verify_min_gpu_count(min_gpus: int = 2) -> bool:
""" verification that we have at least 2 gpus to run dist examples """
has_gpu = torch.accelerator.is_available()
gpu_count = torch.accelerator.device_count()
return has_gpu and gpu_count >= min_gpus

class HybridModel(torch.nn.Module):
r"""
@@ -24,15 +29,15 @@ class HybridModel(torch.nn.Module):
This remote model can get a Remote Reference to the embedding table on the parameter server.
"""

- def __init__(self, remote_emb_module, device):
+ def __init__(self, remote_emb_module, rank):
super(HybridModel, self).__init__()
self.remote_emb_module = remote_emb_module
- self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
- self.device = device
+ self.fc = DDP(torch.nn.Linear(16, 8).to(rank))
+ self.rank = rank

def forward(self, indices, offsets):
emb_lookup = self.remote_emb_module.forward(indices, offsets)
- return self.fc(emb_lookup.cuda(self.device))
+ return self.fc(emb_lookup.to(self.rank))


def _run_trainer(remote_emb_module, rank):
@@ -83,7 +88,7 @@ def get_next_batch(rank):
batch_size += 1

offsets_tensor = torch.LongTensor(offsets)
- target = torch.LongTensor(batch_size).random_(8).cuda(rank)
+ target = torch.LongTensor(batch_size).random_(8).to(rank)
yield indices, offsets_tensor, target

# Train for 100 epochs
@@ -145,9 +150,16 @@ def run_worker(rank, world_size):
for fut in futs:
fut.wait()
elif rank <= 1:
if torch.accelerator.is_available():
acc = torch.accelerator.current_accelerator()
device = torch.device(acc)
else:
device = torch.device("cpu")
backend = torch.distributed.get_default_backend_for_device(device)
# Pin this trainer rank to its own accelerator device.
torch.accelerator.set_device_index(rank)
# Initialize process group for Distributed DataParallel on trainers.
dist.init_process_group(
backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
backend=backend, rank=rank, world_size=2, init_method="tcp://localhost:29500"
)

# Initialize RPC.
@@ -172,9 +184,18 @@ def run_worker(rank, world_size):

# block until all rpcs finish
rpc.shutdown()

# Clean up process group for trainers to avoid resource leaks
if rank <= 1:
dist.destroy_process_group()


if __name__ == "__main__":
# 2 trainers, 1 parameter server, 1 master.
world_size = 4
_min_gpu_count = 2
if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
exit()
mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
print("Distributed RPC example completed successfully.")
3 changes: 2 additions & 1 deletion distributed/rpc/ddp_rpc/requirements.txt
@@ -1 +1,2 @@
- torch>=1.6.0
+ torch>=2.7.1
+ numpy
2 changes: 2 additions & 0 deletions distributed/rpc/parameter_server/requirements.txt
@@ -0,0 +1,2 @@
+ torch>=2.7.1
+ numpy
36 changes: 21 additions & 15 deletions distributed/rpc/parameter_server/rpc_parameter_server.py
@@ -20,15 +20,19 @@ def __init__(self, num_gpus=0):
super(Net, self).__init__()
print(f"Using {num_gpus} GPUs to train")
self.num_gpus = num_gpus
- device = torch.device(
- "cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu")
+ if torch.accelerator.is_available() and self.num_gpus > 0:
+ acc = torch.accelerator.current_accelerator()
+ device = torch.device(f'{acc}:0')
+ else:
+ device = torch.device("cpu")
print(f"Putting first 2 convs on {str(device)}")
- # Put conv layers on the first cuda device
+ # Put conv layers on the first accelerator device
self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)
- # Put rest of the network on the 2nd cuda device, if there is one
- if "cuda" in str(device) and num_gpus > 1:
- device = torch.device("cuda:1")
+ # Put rest of the network on the 2nd accelerator device, if there is one
+ if torch.accelerator.is_available() and self.num_gpus > 1:
+ acc = torch.accelerator.current_accelerator()
+ device = torch.device(f'{acc}:1')

print(f"Putting rest of layers on {str(device)}")
self.dropout1 = nn.Dropout2d(0.25).to(device)
@@ -72,21 +76,22 @@ def call_method(method, rref, *args, **kwargs):
# <foo_instance>.bar(arg1, arg2) on the remote node and getting the result
# back.


def remote_method(method, rref, *args, **kwargs):
args = [method, rref] + list(args)
return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)


# --------- Parameter Server --------------------
class ParameterServer(nn.Module):
def __init__(self, num_gpus=0):
super().__init__()
model = Net(num_gpus=num_gpus)
self.model = model
- self.input_device = torch.device(
- "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")
+ if torch.accelerator.is_available() and num_gpus > 0:
+ acc = torch.accelerator.current_accelerator()
+ self.input_device = torch.device(f'{acc}:0')
+ else:
+ self.input_device = torch.device("cpu")

def forward(self, inp):
inp = inp.to(self.input_device)
out = self.model(inp)
@@ -113,11 +118,9 @@ def get_param_rrefs(self):
param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
return param_rrefs


param_server = None
global_lock = Lock()


def get_parameter_server(num_gpus=0):
global param_server
# Ensure that we get only one handle to the ParameterServer.
@@ -197,8 +200,11 @@ def get_accuracy(test_loader, model):
model.eval()
correct_sum = 0
# Use GPU to evaluate if possible
device = torch.device("cuda:0" if model.num_gpus > 0
and torch.cuda.is_available() else "cpu")
if torch.accelerator.is_available() and model.num_gpus > 0:
acc = torch.accelerator.current_accelerator()
device = torch.device(f'{acc}:0')
else:
device = torch.device("cpu")
with torch.no_grad():
for i, (data, target) in enumerate(test_loader):
out = model(data)
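The Net used by the parameter server keeps its two-device, model-parallel layout; only the device strings change, from hard-coded cuda:0/cuda:1 to whatever the current accelerator is. Below is a minimal sketch of that placement pattern; the module name and layer shapes are hypothetical, and it assumes at least two devices of the current accelerator type.

    import torch
    import torch.nn as nn

    class TwoDeviceNet(nn.Module):
        def __init__(self):
            super().__init__()
            acc = torch.accelerator.current_accelerator()      # e.g. device(type='cuda')
            first = torch.device(f"{acc}:0")
            second = torch.device(f"{acc}:1")
            self.conv = nn.Conv2d(1, 32, 3, 1).to(first)       # early layers on device 0
            self.head = nn.Linear(32, 10).to(second)           # later layers on device 1

        def forward(self, x):
            x = torch.relu(self.conv(x.to(self.conv.weight.device)))
            x = x.mean(dim=(2, 3))                             # global average pool -> (N, 32)
            return self.head(x.to(self.head.weight.device))    # hop to device 1 for the classifier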
3 changes: 2 additions & 1 deletion distributed/rpc/rnn/requirements.txt
@@ -1 +1,2 @@
- torch
+ torch>=2.7.1
+ numpy
10 changes: 6 additions & 4 deletions distributed/rpc/rnn/rnn.py
@@ -43,13 +43,15 @@ def __init__(self, ntoken, ninp, dropout):
super(EmbeddingTable, self).__init__()
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, ninp)
- if torch.cuda.is_available():
- self.encoder = self.encoder.cuda()
+ if torch.accelerator.is_available():
+ device = torch.accelerator.current_accelerator()
+ self.encoder = self.encoder.to(device)
nn.init.uniform_(self.encoder.weight, -0.1, 0.1)

def forward(self, input):
- if torch.cuda.is_available():
- input = input.cuda()
+ if torch.accelerator.is_available():
+ device = torch.accelerator.current_accelerator()
+ input = input.to(device)
return self.drop(self.encoder(input)).cpu()


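In the RNN example the embedding table now lives on the current accelerator when one is present, and its forward output is moved back to CPU before being returned, since the result crosses the RPC boundary to the trainer process. A small, illustrative sketch of that pattern with the CPU fallback written out (the class name here is hypothetical; EmbeddingTable above keeps its original structure):

    import torch
    import torch.nn as nn

    class DeviceLocalEmbedding(nn.Module):
        def __init__(self, ntoken: int, ninp: int):
            super().__init__()
            self.device = (torch.accelerator.current_accelerator()
                           if torch.accelerator.is_available()
                           else torch.device("cpu"))
            self.encoder = nn.Embedding(ntoken, ninp).to(self.device)

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            # Look up on the local device, return a CPU tensor for RPC transport.
            return self.encoder(input.to(self.device)).cpu()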
10 changes: 10 additions & 0 deletions run_distributed_examples.sh
@@ -58,10 +58,20 @@ function distributed_minGPT-ddp() {
uv run bash run_example.sh mingpt/main.py || error "minGPT example failed"
}

function distributed_rpc_ddp_rpc() {
uv run main.py || error "ddp_rpc example failed"
}

function distributed_rpc_rnn() {
uv run main.py || error "rpc_rnn example failed"
}

function run_all() {
run distributed/tensor_parallelism
run distributed/ddp
run distributed/minGPT-ddp
run distributed/rpc/ddp_rpc
run distributed/rpc/rnn
}

# by default, run all examples
2 changes: 1 addition & 1 deletion utils.sh
@@ -48,7 +48,7 @@ function run() {
if start $EXAMPLE; then
# drop trailing slash (occurs due to auto completion in bash interactive mode)
# replace slashes with underscores: this allows to call nested examples
- EXAMPLE_FN=$(echo $EXAMPLE | sed "s@/\$@@" | sed 's@/@_@')
+ EXAMPLE_FN=$(echo $EXAMPLE | sed "s@/\$@@" | sed 's@/@_@g')
$EXAMPLE_FN
fi
stop $EXAMPLE
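With the added g flag, every slash in the example path is replaced, so a nested path such as distributed/rpc/ddp_rpc maps to the runner function distributed_rpc_ddp_rpc defined above; without it, only the first slash was rewritten and the new rpc runner functions would never be found.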