fix: add lint workflow #96

Open · wants to merge 1 commit into main

28 changes: 28 additions & 0 deletions .github/workflows/linting.yml
@@ -0,0 +1,28 @@
name: Linting

on:
  push:
    paths-ignore:
      - '*.md'

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install dependencies
        run: |
          pip install ruff black isort
      - name: Run black
        run: |
          black --check .
      - name: Run ruff
        run: |
          ruff check .
      - name: Run isort
        run: |
          isort **/*.py -c -v
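
For reference, the same three checks can be reproduced locally before pushing. The helper below is a minimal sketch, not part of this PR; it assumes black, ruff, and isort are installed in the current environment (pip install ruff black isort) and uses isort's recursive --check-only mode instead of the shell glob used in the workflow.

# local_lint.py -- hypothetical helper, not included in this PR.
# Runs the same formatters/linters as the Linting workflow in check-only mode.
import subprocess
import sys

CHECKS = [
    ["black", "--check", "."],       # formatting check, no files modified
    ["ruff", "check", "."],          # lint rules
    ["isort", "--check-only", "."],  # import ordering, recurses into the repo
]

def main() -> int:
    exit_code = 0
    for cmd in CHECKS:
        print("$", " ".join(cmd))
        if subprocess.run(cmd).returncode != 0:
            exit_code = 1  # keep going so every failure is reported
    return exit_code

if __name__ == "__main__":
    sys.exit(main())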
6 changes: 4 additions & 2 deletions cifar100-resnet50/prepare_cifar_data.py
@@ -1,13 +1,15 @@
from torchvision import datasets


def main():

ROOT = './data/cifar100'
ROOT = "./data/cifar100"
train_dataset = datasets.CIFAR100(root=ROOT, train=True, download=True)

# Hold-out this data for final evaluation
valid_dataset = datasets.CIFAR100(root=ROOT, train=False, download=True)

if __name__ == '__main__':

if __name__ == "__main__":

main()
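
As a quick sanity check (a sketch, assuming the script above has already populated ./data/cifar100), the cached dataset can be reopened with downloading disabled:

from torchvision import datasets

ROOT = "./data/cifar100"

# download=False: fail fast if prepare_cifar_data.py has not been run yet
train_dataset = datasets.CIFAR100(root=ROOT, train=True, download=False)
valid_dataset = datasets.CIFAR100(root=ROOT, train=False, download=False)

# CIFAR-100 ships 50,000 training and 10,000 test images over 100 classes
print(len(train_dataset), len(valid_dataset), len(train_dataset.classes))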
149 changes: 94 additions & 55 deletions cifar100-resnet50/train.py
@@ -1,36 +1,57 @@
import os, sys, time, warnings, pickle, random
from pathlib import Path
import argparse
import numpy as np
import os
import random
import time
import warnings
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim
import torch.utils.data
import torch.distributed as dist

from torch.nn.parallel import DistributedDataParallel as DDP
from cycling_utils import (
AtomicDirectory,
InterruptableDistributedSampler,
atomic_torch_save,
)
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torchvision import models, datasets, transforms
from cycling_utils import InterruptableDistributedSampler, AtomicDirectory, atomic_torch_save
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision import datasets, models, transforms

warnings.filterwarnings("ignore")


def topk_accuracy(preds, targs, topk=1, normalize=True):
topk_preds = preds.argsort(axis=1, descending=True)[:,:topk]
topk_accurate = np.array([[t in p] for t,p in zip(targs,topk_preds)])
topk_preds = preds.argsort(axis=1, descending=True)[:, :topk]
topk_accurate = np.array([[t in p] for t, p in zip(targs, topk_preds)])
if normalize:
return topk_accurate.sum() / len(targs)
else:
return topk_accurate.sum()


def model_builder(model_parameters):
model = models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, model_parameters['output_size'])
model.fc = nn.Linear(num_ftrs, model_parameters["output_size"])
return model

def train_eval_ddp(device_id, rank, world_size, model_parameters, nepochs, batch_size, accumulate, train_dataset, valid_dataset, evaluate, learning_rate):

def train_eval_ddp(
device_id,
rank,
world_size,
model_parameters,
nepochs,
batch_size,
accumulate,
train_dataset,
valid_dataset,
evaluate,
learning_rate,
):
# Config cuda rank and model
torch.cuda.empty_cache()
torch.cuda.set_device(device_id)
@@ -47,24 +68,27 @@ def train_eval_ddp(device_id, rank, world_size, model_parameters, nepochs, batch
# Use gradient compression to reduce communication
ddp_model.register_comm_hook(None, default.fp16_compress_hook)

loss_function = nn.CrossEntropyLoss(reduction='sum').to(device_id)
loss_function = nn.CrossEntropyLoss(reduction="sum").to(device_id)
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5, verbose=True)

# Init train and validation samplers and loaders
train_sampler = InterruptableDistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, shuffle=False, num_workers=6)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, sampler=train_sampler, shuffle=False, num_workers=6
)

if evaluate:
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler, shuffle=False, num_workers=6)
valid_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=batch_size, sampler=valid_sampler, shuffle=False, num_workers=6
)

completed_epochs = 0

# init checkpoint saver
output_directory = os.environ["CHECKPOINT_ARTIFACT_PATH"]
saver = AtomicDirectory(output_directory=output_directory, is_master=rank==0)
saver = AtomicDirectory(output_directory=output_directory, is_master=rank == 0)

latest_symlink_file_path = os.path.join(output_directory, saver.symlink_name)
if os.path.islink(latest_symlink_file_path):
@@ -104,14 +128,16 @@ def train_eval_ddp(device_id, rank, world_size, model_parameters, nepochs, batch

train_sampler.advance(len(X_train))

if (i + 1) % accumulate == 0 or (i + 1) == n_train_batches: # Final loop in accumulation cycle, or last batch in dataset
if (i + 1) % accumulate == 0 or (
i + 1
) == n_train_batches: # Final loop in accumulation cycle, or last batch in dataset
z_train = ddp_model(X_train)
loss = loss_function(z_train, y_train)
cumulative_train_loss += loss.item()
train_examples_seen += len(y_train)
loss.backward() # Sync gradients between devices
optimizer.step() # Weight update
optimizer.zero_grad() # Zero grad
loss.backward() # Sync gradients between devices
optimizer.step() # Weight update
optimizer.zero_grad() # Zero grad

if i % 50 == 0:
checkpoint_directory = saver.prepare_checkpoint_directory()
@@ -125,12 +151,13 @@ def train_eval_ddp(device_id, rank, world_size, model_parameters, nepochs, batch
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"sampler_state_dict": train_sampler.state_dict(),
}, os.path.join(checkpoint_directory, "checkpoint.pt")
},
os.path.join(checkpoint_directory, "checkpoint.pt"),
)

saver.symlink_latest(checkpoint_directory)

else: # Otherwise only accumulate gradients locally to save time.
else: # Otherwise only accumulate gradients locally to save time.
with ddp_model.no_sync():
z_train = ddp_model(X_train)
loss = loss_function(z_train, y_train)
@@ -148,7 +175,7 @@ def train_eval_ddp(device_id, rank, world_size, model_parameters, nepochs, batch
valid_examples_seen = 0.0
top1acc = 0.0
top5acc = 0.0

ddp_model.eval()
with torch.no_grad():
for X_valid, y_valid in valid_loader:
@@ -167,7 +194,7 @@ def train_eval_ddp(device_id, rank, world_size, model_parameters, nepochs, batch
vloss = cumulative_valid_loss / valid_examples_seen
top1 = (top1acc / valid_examples_seen) * 100
top5 = (top5acc / valid_examples_seen) * 100

else:
vloss = 0
top1 = 0
@@ -180,21 +207,17 @@ def train_eval_ddp(device_id, rank, world_size, model_parameters, nepochs, batch
result = np.stack(outputs).mean(axis=0)
tloss_, vloss_, top1_, top5_, epoch_duration_ = result

results[epoch] = {
"tloss": tloss_,
"vloss": vloss_,
"top1": top1_,
"top5": top5_,
"time": epoch_duration_
}
results[epoch] = {"tloss": tloss_, "vloss": vloss_, "top1": top1_, "top5": top5_, "time": epoch_duration_}

# Learning rate scheduler reducing by factor of 10 when training loss stops reducing. Likely to overfit first.
# Must apply same operation for all devices to ensure optimizers remain in sync.
scheduler.step(tloss_)

# If main rank, save results and report.
if rank == 0:
print(f'EPOCH {epoch}, TLOSS {tloss_:.3f}, VLOSS {vloss_:.3f}, TOP1 {top1_:.2f}, TOP5 {top5_:.2f}, TIME {epoch_duration_:.3f}')
print(
f"EPOCH {epoch}, TLOSS {tloss_:.3f}, VLOSS {vloss_:.3f}, TOP1 {top1_:.2f}, TOP5 {top5_:.2f}, TIME {epoch_duration_:.3f}"
)

checkpoint_directory = saver.prepare_checkpoint_directory()

@@ -207,7 +230,8 @@ def train_eval_ddp(device_id, rank, world_size, model_parameters, nepochs, batch
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"sampler_state_dict": train_sampler.state_dict(),
}, os.path.join(checkpoint_directory, "checkpoint.pt")
},
os.path.join(checkpoint_directory, "checkpoint.pt"),
)

saver.symlink_latest(checkpoint_directory)
@@ -228,35 +252,50 @@ def train_eval_ddp(device_id, rank, world_size, model_parameters, nepochs, batch
print(args)

# setup distributed training
dist.init_process_group(backend='nccl')
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()
rank = dist.get_rank()
device_id = rank % torch.cuda.device_count()

# Set up train and validation datasets
norm_stats = ((0.5071, 0.4866, 0.4409),(0.2009, 0.1984, 0.2023)) # CIFAR100 training set normalization constants
norm_stats = ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)) # CIFAR100 training set normalization constants
R = 384
train_transform = transforms.Compose([
transforms.AutoAugment(policy = transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop(R),
transforms.ToTensor(), # Also standardizes to range [0,1]
transforms.Normalize(*norm_stats),
])

valid_transform = transforms.Compose([
transforms.Resize(R),
transforms.ToTensor(), # Also standardizes to range [0,1]
transforms.Normalize(*norm_stats),
])
train_transform = transforms.Compose(
[
transforms.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop(R),
transforms.ToTensor(), # Also standardizes to range [0,1]
transforms.Normalize(*norm_stats),
]
)

valid_transform = transforms.Compose(
[
transforms.Resize(R),
transforms.ToTensor(), # Also standardizes to range [0,1]
transforms.Normalize(*norm_stats),
]
)

data_path = os.path.join("/data", args.dataset_id)
train_dataset = datasets.CIFAR100(root=data_path, train=True, transform=train_transform, download=True)

# Hold-out this data for final evaluation
valid_dataset = datasets.CIFAR100(root=data_path, train=False, transform=valid_transform, download=True)

print(f'Train: {len(train_dataset):,.0f}, Valid: {len(valid_dataset):,.0f}')

train_eval_ddp(device_id, rank, world_size, model_parameters={"output_size": 100}, nepochs=args.epochs, batch_size=args.batch_size, accumulate=1, train_dataset=train_dataset, valid_dataset=valid_dataset, evaluate=True, learning_rate=args.lr)

print(f"Train: {len(train_dataset):,.0f}, Valid: {len(valid_dataset):,.0f}")

train_eval_ddp(
device_id,
rank,
world_size,
model_parameters={"output_size": 100},
nepochs=args.epochs,
batch_size=args.batch_size,
accumulate=1,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
evaluate=True,
learning_rate=args.lr,
)
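
The accumulation branch reformatted above follows the usual DDP pattern: intermediate micro-batches accumulate gradients locally under no_sync(), and only the last micro-batch of each cycle triggers the cross-rank all-reduce and the optimizer step. A stripped-down sketch of that pattern (placeholder names, not the script's real objects):

def accumulation_cycle(ddp_model, loss_function, optimizer, loader, accumulate):
    # ddp_model is a torch.nn.parallel.DistributedDataParallel instance;
    # loader, loss_function, optimizer, and accumulate are placeholders.
    n_batches = len(loader)
    for i, (X, y) in enumerate(loader):
        last_in_cycle = (i + 1) % accumulate == 0 or (i + 1) == n_batches
        if last_in_cycle:
            # Gradients are synchronized across ranks on this backward pass.
            loss = loss_function(ddp_model(X), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        else:
            # Skip the all-reduce; gradients just accumulate locally.
            with ddp_model.no_sync():
                loss = loss_function(ddp_model(X), y)
                loss.backward()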