Single Run DDP Example Script #9
Open
LeoRoccoBreedt wants to merge 86 commits into main from lb/pytorch_ddp
Commits (86)
b837338
feat: Added initial Pytorch example to monitor batching and per layer…
LeoRoccoBreedt 9994528
refactor: Update introduction section for more clarity on notebook use
LeoRoccoBreedt 223fde1
chore: change how the custom run id gets automatically generated
LeoRoccoBreedt ceb3a24
chore: update instructions on how users can get and set their API tok…
LeoRoccoBreedt 39f5bb9
chore: update the introduction to be more foundation model training o…
LeoRoccoBreedt 69fdfba
chore: update dataset section with a better description
LeoRoccoBreedt 1cde287
refactor: update training loop where grads, norms and activations are…
LeoRoccoBreedt 2da28f4
refactor: update batch size and edit gradient norm logging code
LeoRoccoBreedt 123fea2
chore: add data file to ignore for pytorch example
LeoRoccoBreedt 86b6d5d
refactor: update model architecture layers and update training loop
LeoRoccoBreedt a39077a
refactor: update accuracy calculation to not output the percentage
LeoRoccoBreedt 3e2d453
refactor: update model architecture layers, accuracy calculation and …
LeoRoccoBreedt 8a3da57
feat: Added a pytorch text-based example that is used to demonstrate …
LeoRoccoBreedt d03fd43
refactor: add validation and test loss calcualtion for each epoch
LeoRoccoBreedt b7f8e08
refactor: update configs and parameters
LeoRoccoBreedt b113a36
refactor: update logged configs
LeoRoccoBreedt 7258993
refactor: calculate activations per layer
LeoRoccoBreedt ae330e6
refactor: add tracking for grad norms
LeoRoccoBreedt d4d160c
refactor: add gradient tracking per epoch
LeoRoccoBreedt 890bf26
chore: remove uneeded section
LeoRoccoBreedt ebd34c7
refactor: add fully connected layer to model for more complexity
LeoRoccoBreedt ae9abc6
chore: fix activation saving for layers
LeoRoccoBreedt 6f23761
refactor: update packages for example
LeoRoccoBreedt 2c82a0f
refactor: update dataset to be used in example
LeoRoccoBreedt 48451d6
refactor: update evalution function that calculates the validation lo…
LeoRoccoBreedt 5b5fbdd
refactor: update training loop to work with new data
LeoRoccoBreedt 8ff6f30
chore: remove unused sections
LeoRoccoBreedt b32b27b
chore: re-organize notebook layout
LeoRoccoBreedt 4d659a2
chore: cleanup and add parameters in right place
LeoRoccoBreedt 5e7ec6e
refactor: add all debugging metrics to the same dictionary variable
LeoRoccoBreedt d27c247
chore: change LSTM layers to see response in logging
LeoRoccoBreedt dc4b63d
refactor: update location where model.train() is called in training loop
LeoRoccoBreedt 1a8d021
refactor: update how HF dataset is downloaded to only download a subset
LeoRoccoBreedt fd7cbc5
refactor: update loading and processing of HF dataset to make code fa…
LeoRoccoBreedt 3be6c3e
chore: add TODO's to address
LeoRoccoBreedt 6e1df27
chore: update sections
LeoRoccoBreedt fbfb0da
refactor: move data downloading section
LeoRoccoBreedt 84149d1
refactor: added evaluate function to model initialization cell
LeoRoccoBreedt 2d69e01
chore: update introduction for the model architecture and helper func…
LeoRoccoBreedt 90e0539
fix: update input for vocab_size
LeoRoccoBreedt 46e1f37
refactor: add the vocab_size calculation to the data formatting section
LeoRoccoBreedt ca1b52d
fix: refactor data loading process from HF
LeoRoccoBreedt 836c452
fix: update validation data to use test subset from HF and comment ou…
LeoRoccoBreedt 814d56f
refactor: create a class to manage hooks for tracking gradients and a…
LeoRoccoBreedt 31968e9
fix: gradients logging to Neptune
LeoRoccoBreedt a9bad61
style: remove old model architecture section
LeoRoccoBreedt 32d46f1
refactor: change attribute names for better readability
LeoRoccoBreedt 32a43b1
style: update sections and model architecture
LeoRoccoBreedt 0c8497b
chore: cleanup commented code
LeoRoccoBreedt f61f734
chore: cleanup model outputs
LeoRoccoBreedt bcc344e
refactor: update quotes for dictionary keys since Colab returns an error
LeoRoccoBreedt df89878
style: update intro and information about debugging metrics
LeoRoccoBreedt 23bd58a
refactor: update to be able to run model on GPUs
LeoRoccoBreedt 0dbd341
style: update ending of notebook with follow along
LeoRoccoBreedt 1f12175
feat: add DataParallel support
LeoRoccoBreedt a183bab
chore: update TODO's
LeoRoccoBreedt 7521fa0
chore: remove text example from this branch
LeoRoccoBreedt 47f3876
refactor: update for DDP training
LeoRoccoBreedt a6995f9
fix: update code cells to python from markdown where appropriate
LeoRoccoBreedt 16084ef
fix: sections
LeoRoccoBreedt 0b6a497
feat: add ddp script
LeoRoccoBreedt 93f74c6
chore: remove unused pytorch notebook
LeoRoccoBreedt c8add00
update to calculate validation loss at each step
LeoRoccoBreedt edaeeef
add Neptune logging to example
LeoRoccoBreedt 34a5bed
refactor: adjust how run is set on rank
LeoRoccoBreedt 99a1cd9
chore: update error checking before executing compute to ensure that …
LeoRoccoBreedt 4ea8530
refactor: cleanup model class to not be dependent on parameters dicti…
LeoRoccoBreedt 3fb95c4
refactor: update run initialization and set environment variables in …
LeoRoccoBreedt b470e41
chore: move ddp.py location in repo
LeoRoccoBreedt 72c6ed6
refactor: update where environment variables are setup
LeoRoccoBreedt 7c7b557
chore: accept changes based on pre-commit recommendations
LeoRoccoBreedt 5cd7488
chore: anonymize api token and project details
LeoRoccoBreedt 401894d
chore: accept pre-commit changes
LeoRoccoBreedt 3a769d6
Merge commit 'fc4bc5ee2d6e3297c8611991e60f372ea785d213' into lb/pytor…
LeoRoccoBreedt 1569ee5
refactor: simplify and cleanup example to focus more on DDP that havi…
LeoRoccoBreedt a0c8d19
chore: rename file
LeoRoccoBreedt 2688009
chore: run pre-commit
LeoRoccoBreedt 19cb6e4
chore: add requirements and bash scripts and update GH worflow config…
LeoRoccoBreedt 5d7dc6a
chore: comment out DDP automated tests workflow
LeoRoccoBreedt 9c2eddc
chore: include importlb-metadata package for Python 3.9 tests
LeoRoccoBreedt 9362e4b
chore: restrict python version
LeoRoccoBreedt 43bd373
chore: update bach to gracefully stop execution if macOS or Windows O…
LeoRoccoBreedt 6296a3b
chore: update requirements and bash
LeoRoccoBreedt 8527d5e
test remove DDP testing workflow
LeoRoccoBreedt 57dec4f
chore: update readme
LeoRoccoBreedt fc03e05
chore: ignore ddp training scripts
LeoRoccoBreedt
.gitignore
@@ -2,3 +2,4 @@

# Data files
/how-to-guides/hpo/**/mnist
/integrations-and-supported-tools/pytorch/data
requirements.txt
@@ -0,0 +1,4 @@
neptune-scale
torch
torchvision
importlib-metadata
Launch script (bash)
@@ -0,0 +1,8 @@
#!/bin/bash
set -e

echo "Installing requirements..."
pip install -U -r requirements.txt

echo "Running train_ddp_single_run.py..."
torchrun --nproc_per_node=2 --nnodes=1 train_ddp_single_run.py
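The launch script assumes a single node with two GPUs. On a machine with a different GPU count, the same script can be started by invoking torchrun directly and adjusting --nproc_per_node to match the available devices; for example, on a hypothetical 4-GPU node:

# Hypothetical invocation for a node with 4 GPUs; --nproc_per_node should match the GPU count
torchrun --nproc_per_node=4 --nnodes=1 train_ddp_single_run.py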
how-to-guides/ddp-training/scripts/train_ddp_single_run.py (287 additions, 0 deletions)
@@ -0,0 +1,287 @@
# Important: This script can only be run when using multiple GPUs (> 1)

import os
from typing import Any, Dict, Optional, Tuple

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms


def create_dataloader_minst(
    rank: int, world_size: int, batch_size: int
) -> Tuple[DataLoader, DataLoader]:
    """
    Create distributed data loaders for the MNIST dataset.

    Args:
        rank (int): Process rank
        world_size (int): Total number of processes
        batch_size (int): Batch size per process

    Returns:
        Tuple[DataLoader, DataLoader]: Training and validation data loaders
    """
    # Transform to normalize the data and convert it to a tensor
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),  # Normalize the images to the range [-1, 1]
        ]
    )

    # Download and load the MNIST dataset
    train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)

    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)

    return train_loader, val_loader


# Simple Convolutional Neural Network model for MNIST
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(
            1, 32, kernel_size=3, padding=1
        )  # Input channels = 1 (grayscale images)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # Flattened size of image after convolution layers
        self.fc2 = nn.Linear(128, 10)  # 10 output classes for digits 0-9

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)  # Pooling layer to downsample
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)  # Flatten the tensor for the fully connected layer
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Function to evaluate the model during validation
def evaluate(model, data_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode
    correct_preds = 0
    total_preds = 0
    epoch_loss = 0
    with torch.no_grad():  # Disable gradient tracking during evaluation
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)

            # Forward pass
            output = model(data)
            loss = criterion(output, target)  # Compute the loss
            epoch_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(output.data, 1)
            total_preds += target.size(0)
            correct_preds += (predicted == target).sum().item()

    accuracy = correct_preds / total_preds
    return epoch_loss / len(data_loader), accuracy


## Setup distributed environment
def setup_distributed(rank: int, world_size: int, backend: str) -> None:
    """
    Initialize the distributed environment.

    Args:
        rank (int): Process rank
        world_size (int): Total number of processes
        backend (str): Distributed backend to use

    Raises:
        RuntimeError: If distributed initialization fails
    """
    try:
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
    except Exception as e:
        raise RuntimeError(f"Failed to initialize distributed environment: {e}")


def train(
    rank: int,
    model: nn.Module,
    params: Dict[str, Any],
    train_loader: DataLoader,
    val_loader: DataLoader,
    run: Optional[Any] = None,
) -> None:
    """
    Train the model using distributed data parallel.

    Args:
        rank (int): Process rank
        model (nn.Module): Model to be trained
        params (Dict[str, Any]): Training parameters
        train_loader (DataLoader): Training data loader
        val_loader (DataLoader): Validation data loader
        run (Optional[Any]): Neptune run object for logging

    Raises:
        RuntimeError: If training fails
    """
    # Instantiate the device, loss function, and optimizer
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    print(f"Rank {rank} using device: {device}")

    # Move model to device first
    model = model.to(device)

    # Then wrap with DDP
    model = DDP(model, device_ids=[rank])

    optimizer = optim.Adam(model.parameters(), lr=params["learning_rate"])
    criterion = nn.CrossEntropyLoss()

    try:
        # Training loop
        num_epochs = params["epochs"]
        step_counter = 0
        for epoch in range(num_epochs):
            model.train()
            epoch_loss = 0
            correct_preds = 0
            total_preds = 0

            # Training step
            for batch_idx, (data, target) in enumerate(train_loader, 0):
                step_counter += 1
                optimizer.zero_grad()
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

                # Calculate accuracy
                _, predicted = torch.max(output.data, 1)
                total_preds += target.size(0)
                correct_preds += (predicted == target).sum().item()
                batch_accuracy = correct_preds / total_preds

                # Validation step per training step
                val_loss, val_accuracy = evaluate(
                    model, val_loader, criterion, device
                )  # Evaluate after each step

                if rank == 0:
                    # Log metrics
                    run.log_metrics(
                        data={
                            "metrics/train/loss": loss.item(),
                            "metrics/train/accuracy": batch_accuracy,
                            "metrics/validation/loss": val_loss,
                            "metrics/validation/accuracy": val_accuracy,
                            "epoch_value": epoch,
                        },
                        step=step_counter,
                    )

                dist.barrier()  # Synchronize processes before moving to the next step

        dist.destroy_process_group()

    except Exception as e:
        raise RuntimeError(f"Error during training process (Rank {rank}): {e}")


def run_ddp(rank: int, world_size: int, params: Dict[str, Any]) -> None:
    """
    Run distributed data parallel training.

    Args:
        rank (int): Process rank
        world_size (int): Total number of processes
        params (Dict[str, Any]): Training parameters

    Raises:
        RuntimeError: If DDP training fails
    """
    try:
        setup_distributed(rank, world_size, "nccl")
        train_loader, val_loader = create_dataloader_minst(rank, world_size, params["batch_size"])

        # Initialize the Neptune logger only on the main process (rank 0)
        if rank == 0:
            from uuid import uuid4

            from neptune_scale import Run

            # Initialize Neptune logger
            run = Run(run_id=f"ddp-{uuid4()}", experiment_name="pytorch-ddp-experiment")

            # Log all parameters
            run.log_configs(
                {
                    "config/learning_rate": params["learning_rate"],
                    "config/batch_size": params["batch_size"],
                    "config/num_gpus": params["num_gpus"],
                    "config/n_classes": params["n_classes"],
                }
            )

            # Add descriptive tags
            run.add_tags(tags=["Torch-MINST", "ddp", "single-node", params["optimizer"]])

            print(f"View experiment charts:\n{run.get_run_url() + '&detailsTab=charts'}")
        else:
            run = None

        model = SimpleCNN()
        train(rank, model, params, train_loader, val_loader, run)

        # Once training is finished, close the Neptune run from the main process
        if rank == 0:
            run.close()
    except Exception as e:
        raise RuntimeError(f"Failed to run DDP training: {e}")
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


# Run DDP
if __name__ == "__main__":

    # Set environment variables for DDP setup
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Set environment variables for Neptune
    os.environ["NEPTUNE_PROJECT"] = "your_workspace_name/your_project_name"
    os.environ["NEPTUNE_API_TOKEN"] = "your_api_token"

    # Set parameters
    params = {
        "optimizer": "Adam",
        "batch_size": 512,
        "learning_rate": 0.01,
        "epochs": 5,
        "num_gpus": torch.cuda.device_count(),
        "n_classes": 10,
    }

    # Spawn the DDP job across multiple GPUs
    print(f"Example will use {params['num_gpus']} GPUs")
    mp.set_start_method("spawn", force=True)
    mp.spawn(run_ddp, args=(params["num_gpus"], params), nprocs=params["num_gpus"], join=True)
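Note that the bash launch script invokes this file with torchrun, while the script itself spawns one worker per GPU with mp.spawn. When launching with torchrun directly, the entry point would instead read the rank and world size that torchrun exports rather than calling mp.spawn. The following is a rough, illustrative sketch of such an entry point (an assumption for reference, not part of the committed example):

# Illustrative sketch of a torchrun-style entry point (assumption, not the committed code).
# torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT itself.
import os

import torch
import torch.distributed as dist


def main() -> None:
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node, set by torchrun
    dist.init_process_group(backend="nccl")  # rank and world size are read from the environment
    torch.cuda.set_device(local_rank)

    # Tiny sanity check: sum the ranks across all processes
    t = torch.tensor([dist.get_rank()], device=f"cuda:{local_rank}", dtype=torch.float32)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    if dist.get_rank() == 0:
        print(f"Sum of ranks across {dist.get_world_size()} processes: {t.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()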
Copilot review comment:
The docs link for 'DDP training scripts' is empty. Please provide a valid URL or remove the placeholder link to prevent confusion.
Reviewer comment:
@LeoRoccoBreedt - can you make the suggested changes, including the ones hidden? They are valid.