Skip to content

Commit 33eadf8

Browse files
d4l3kfacebook-github-bot
authored andcommitted
classy vision example
Summary: WIP This is an example of a simple classy vision model using torchx. Reviewed By: kiukchung Differential Revision: D28498759 fbshipit-source-id: 41e034d9d08c59a661181aaec7fda9f6b895b269
1 parent 040742c commit 33eadf8

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""
2+
description: Runs the example lightning_classy_vision app.
3+
arguments:
4+
- name: --image
5+
type: str
6+
help: image to run (e.g. foobar:latest)
7+
- name: --resource
8+
type: str
9+
help: resource spec
10+
default: T1
11+
- name: --output_path
12+
type: str
13+
help: output path for model checkpoints (e.g. file:///foo/bar)
14+
required: true
15+
- name: --load_path
16+
type: str
17+
help: path to load pretrained model from
18+
default: ""
19+
- name: --log_dir
20+
type: str
21+
help: path to save tensorboard logs to
22+
default: "/logs"
23+
"""
24+
25+
import torchx.specs.api as torchx
26+
import torchx.schedulers.fb.resource as resource
27+
28+
container = torchx.Container(image=args.image).require(resources=resource.get(args.resource))
29+
entrypoint = "main"
30+
31+
trainer_role = (
32+
torchx.Role(
33+
name="trainer"
34+
)
35+
.runs(
36+
"main",
37+
"--output_path",
38+
args.output_path,
39+
"--load_path",
40+
args.load_path,
41+
"--log_dir",
42+
args.log_dir,
43+
)
44+
.on(container)
45+
.replicas(1)
46+
)
47+
48+
app = torchx.Application("examples-lightning_classy_vision").of(trainer_role)
49+
export(app)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-unsafe
9+
10+
import argparse
11+
import sys
12+
from typing import List
13+
14+
import pytorch_lightning as pl
15+
import torch
16+
from classy_vision.dataset.classy_dataset import ClassyDataset
17+
from classy_vision.dataset.core.random_image_datasets import (
18+
RandomImageDataset,
19+
SampleType,
20+
)
21+
from pytorch_lightning.callbacks import ModelCheckpoint
22+
from pytorch_lightning.loggers import TensorBoardLogger
23+
from torch.nn import functional as F
24+
from torch.utils.data import DataLoader
25+
from torchvision import transforms
26+
27+
28+
class SyntheticMNIST(ClassyDataset):
29+
def __init__(self, transform):
30+
batchsize_per_replica = 16
31+
shuffle = True
32+
num_samples = 1000
33+
dataset = RandomImageDataset(
34+
crop_size=28,
35+
num_channels=3,
36+
num_samples=num_samples,
37+
num_classes=10,
38+
seed=1234,
39+
sample_type=SampleType.TUPLE,
40+
)
41+
super().__init__(
42+
dataset, batchsize_per_replica, shuffle, transform, num_samples
43+
)
44+
45+
46+
class MNISTModel(pl.LightningModule):
47+
def __init__(self):
48+
super(MNISTModel, self).__init__()
49+
self.l1 = torch.nn.Linear(28 * 28, 10)
50+
51+
def forward(self, x):
52+
return torch.relu(self.l1(x.view(x.size(0), -1)))
53+
54+
def training_step(self, batch, batch_nb):
55+
x, y = batch
56+
loss = F.cross_entropy(self(x), y)
57+
return loss
58+
59+
def configure_optimizers(self):
60+
return torch.optim.Adam(self.parameters(), lr=0.02)
61+
62+
63+
def parse_args(argv: List[str]) -> argparse.Namespace:
64+
parser = argparse.ArgumentParser(
65+
description="pytorch lightning + classy vision TorchX example app"
66+
)
67+
parser.add_argument(
68+
"--epochs", type=int, default=3, help="number of epochs to train"
69+
)
70+
parser.add_argument(
71+
"--batch_size", type=int, default=32, help="batch size to use for traiing"
72+
)
73+
parser.add_argument("--load_path", type=str, help="checkpoint path to load from")
74+
parser.add_argument(
75+
"--output_path",
76+
type=str,
77+
help="path to place checkpoints and model outputs",
78+
required=True,
79+
)
80+
parser.add_argument(
81+
"--log_dir", type=str, help="directory to place the logs", default="/tmp"
82+
)
83+
84+
return parser.parse_args(argv)
85+
86+
87+
def main(argv):
88+
args = parse_args(argv)
89+
90+
# Init our model
91+
mnist_model = MNISTModel()
92+
93+
# Init DataLoader from MNIST Dataset
94+
img_transform = transforms.Compose(
95+
[
96+
transforms.Grayscale(),
97+
transforms.ToTensor(),
98+
]
99+
)
100+
train_ds = SyntheticMNIST(
101+
transform=lambda x: (img_transform(x[0]), x[1]),
102+
)
103+
train_loader = DataLoader(train_ds, batch_size=args.batch_size)
104+
105+
checkpoint_callback = ModelCheckpoint(
106+
monitor="train_loss",
107+
dirpath=args.output_path,
108+
save_last=True,
109+
)
110+
if args.load_path:
111+
print(f"loading checkpoint: {args.load_path}...")
112+
mnist_model.load_from_checkpoint(checkpoint_path=args.load_path)
113+
114+
logger = TensorBoardLogger(save_dir=args.log_dir, version=1, name="lightning_logs")
115+
116+
# Initialize a trainer
117+
trainer = pl.Trainer(
118+
logger=logger,
119+
max_epochs=args.epochs,
120+
callbacks=[checkpoint_callback],
121+
)
122+
123+
# Train the model ⚡
124+
trainer.fit(mnist_model, train_loader)
125+
126+
127+
if __name__ == "__main__":
128+
main(sys.argv[1:])

0 commit comments

Comments
 (0)