[SPARK-41777][PYSPARK][ML] Integration testing for TorchDistributor

### What changes were proposed in this pull request?

Added integration tests for running distributed training on files.
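
For context, a minimal sketch of the API these tests exercise (assuming a local two-process CPU run; `train.py` and the `"0.01"` learning-rate argument are illustrative placeholders, not files from this PR):

```python
from pyspark.ml.torch.distributor import TorchDistributor

# Sketch only: run a standalone PyTorch training script with two local CPU
# processes; the script path and learning-rate argument are placeholders.
TorchDistributor(num_processes=2, local_mode=True, use_gpu=False).run(
    "train.py", "0.01"
)
```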

### Why are the changes needed?

These changes add integration test coverage for TorchDistributor when training is driven from a standalone PyTorch file, in both local and distributed mode.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

N/A

Closes apache#39637 from rithwik-db/integration-testing.

Authored-by: Rithwik Ediga Lakhamsani <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
rithwik-db authored and HyukjinKwon committed Jan 21, 2023
1 parent e38a1b7 commit e0b09a1
Showing 5 changed files with 138 additions and 0 deletions.
3 changes: 3 additions & 0 deletions dev/infra/Dockerfile
@@ -69,3 +69,6 @@ RUN python3.9 -m pip install numpy pyarrow 'pandas<=1.5.3' scipy unittest-xml-re

# Add Python deps for Spark Connect.
RUN python3.9 -m pip install grpcio protobuf googleapis-common-protos grpcio-status

# Add torch as a testing dependency for TorchDistributor
RUN python3.9 -m pip install torch torchvision
4 changes: 4 additions & 0 deletions dev/requirements.txt
@@ -59,3 +59,7 @@ googleapis-common-protos==1.56.4
mypy-protobuf==3.3.0
googleapis-common-protos-stubs==2.2.0
grpc-stubs==1.24.11

# TorchDistributor dependencies
torch==1.13.1
torchvision==0.14.1
14 changes: 14 additions & 0 deletions python/pyspark/ml/torch/tests/test_distributor.py
@@ -289,6 +289,13 @@ def test_local_training_succeeds(self) -> None:
                if cuda_env_var:
                    self.delete_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})

    def test_local_file_with_pytorch(self) -> None:
        test_file_path = "python/test_support/test_pytorch_training_file.py"
        learning_rate_str = "0.01"
        TorchDistributor(num_processes=2, local_mode=True, use_gpu=False).run(
            test_file_path, learning_rate_str
        )


class TorchDistributorDistributedUnitTests(unittest.TestCase):
    def setUp(self) -> None:
@@ -350,6 +357,13 @@ def test_get_num_tasks_distributed(self) -> None:

        self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1")

    def test_distributed_file_with_pytorch(self) -> None:
        test_file_path = "python/test_support/test_pytorch_training_file.py"
        learning_rate_str = "0.01"
        TorchDistributor(num_processes=2, local_mode=False, use_gpu=False).run(
            test_file_path, learning_rate_str
        )


class TorchWrapperUnitTests(unittest.TestCase):
    def test_clean_and_terminate(self) -> None:
2 changes: 2 additions & 0 deletions python/pyspark/ml/torch/torch_run_process_wrapper.py
@@ -52,6 +52,8 @@ def check_parent_alive(task: "subprocess.Popen") -> None:
    cmd = [sys.executable, "-m", "torch.distributed.run", *args]
    task = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        stdin=subprocess.PIPE,
        env=os.environ,
    )
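
For context, piping the child's stdout/stderr lets the wrapper forward `torch.distributed.run` output back to the parent; a minimal sketch of that pattern (not the wrapper's exact code, and the child command below is a placeholder):

```python
import subprocess
import sys

# Sketch: stream a piped child process's output back to our stdout.
# The child command here is a stand-in, not the distributor's real invocation.
task = subprocess.Popen(
    [sys.executable, "-c", "print('hello from child')"],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
)
assert task.stdout is not None
for line in task.stdout:
    print(line, end="")
task.wait()
```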
115 changes: 115 additions & 0 deletions python/test_support/test_pytorch_training_file.py
@@ -0,0 +1,115 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# type: ignore

batch_size = 100
num_epochs = 3
momentum = 0.5
log_interval = 100

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import tempfile
import shutil


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)


def train_one_epoch(model, data_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(data_loader) * len(data),
                    100.0 * batch_idx / len(data_loader),
                    loss.item(),
                )
            )


def train(learning_rate):
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data.distributed import DistributedSampler

    print("Running distributed training")
    dist.init_process_group("gloo")

    temp_dir = tempfile.mkdtemp()

    train_dataset = datasets.MNIST(
        temp_dir,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )

    train_sampler = DistributedSampler(dataset=train_dataset)
    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler
    )

    model = Net()
    ddp_model = DDP(model)

    optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=momentum)
    for epoch in range(1, num_epochs + 1):
        train_one_epoch(ddp_model, data_loader, optimizer, epoch)

    dist.destroy_process_group()

    shutil.rmtree(temp_dir)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("lr", help="learning_rate", default=0.001)
    args = parser.parse_args()
    print("learning rate chosen: ", float(args.lr))
    train(float(args.lr))
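
For local debugging outside the test suite, a file like this could also be launched directly with PyTorch's elastic launcher; a sketch assuming torch is installed and two CPU processes (the path and learning-rate argument simply mirror the values used by the new tests):

```python
import subprocess
import sys

# Sketch: run the training file standalone via torch.distributed.run,
# roughly the launcher that TorchDistributor wraps for local CPU runs.
subprocess.run(
    [
        sys.executable,
        "-m",
        "torch.distributed.run",
        "--standalone",
        "--nproc_per_node=2",
        "python/test_support/test_pytorch_training_file.py",
        "0.01",
    ],
    check=True,
)
```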
