diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile
index f226058f186fb..cce4c0cf83056 100644
--- a/dev/infra/Dockerfile
+++ b/dev/infra/Dockerfile
@@ -69,3 +69,6 @@ RUN python3.9 -m pip install numpy pyarrow 'pandas<=1.5.3' scipy unittest-xml-re
 
 # Add Python deps for Spark Connect.
 RUN python3.9 -m pip install grpcio protobuf googleapis-common-protos grpcio-status
+
+# Add torch as a testing dependency for TorchDistributor
+RUN python3.9 -m pip install torch torchvision
diff --git a/dev/requirements.txt b/dev/requirements.txt
index 1d978c4602c8f..77a508621fb02 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -59,3 +59,7 @@ googleapis-common-protos==1.56.4
 mypy-protobuf==3.3.0
 googleapis-common-protos-stubs==2.2.0
 grpc-stubs==1.24.11
+
+# TorchDistributor dependencies
+torch==1.13.1
+torchvision==0.14.1
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index 619e733c0bb9d..747229cb9fd28 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -289,6 +289,13 @@ def test_local_training_succeeds(self) -> None:
         if cuda_env_var:
             self.delete_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
 
+    def test_local_file_with_pytorch(self) -> None:
+        test_file_path = "python/test_support/test_pytorch_training_file.py"
+        learning_rate_str = "0.01"
+        TorchDistributor(num_processes=2, local_mode=True, use_gpu=False).run(
+            test_file_path, learning_rate_str
+        )
+
 
 class TorchDistributorDistributedUnitTests(unittest.TestCase):
     def setUp(self) -> None:
@@ -350,6 +357,13 @@ def test_get_num_tasks_distributed(self) -> None:
 
         self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1")
 
+    def test_distributed_file_with_pytorch(self) -> None:
+        test_file_path = "python/test_support/test_pytorch_training_file.py"
+        learning_rate_str = "0.01"
+        TorchDistributor(num_processes=2, local_mode=False, use_gpu=False).run(
+            test_file_path, learning_rate_str
+        )
+
 
 class TorchWrapperUnitTests(unittest.TestCase):
     def test_clean_and_terminate(self) -> None:
diff --git a/python/pyspark/ml/torch/torch_run_process_wrapper.py b/python/pyspark/ml/torch/torch_run_process_wrapper.py
index 6b5b6a1d0be4e..7439d09d0c052 100644
--- a/python/pyspark/ml/torch/torch_run_process_wrapper.py
+++ b/python/pyspark/ml/torch/torch_run_process_wrapper.py
@@ -52,6 +52,8 @@ def check_parent_alive(task: "subprocess.Popen") -> None:
     cmd = [sys.executable, "-m", "torch.distributed.run", *args]
     task = subprocess.Popen(
         cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
         stdin=subprocess.PIPE,
         env=os.environ,
     )
diff --git a/python/test_support/test_pytorch_training_file.py b/python/test_support/test_pytorch_training_file.py
new file mode 100644
index 0000000000000..4107197acfd88
--- /dev/null
+++ b/python/test_support/test_pytorch_training_file.py
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# type: ignore
+
+batch_size = 100
+num_epochs = 3
+momentum = 0.5
+log_interval = 100
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+import tempfile
+import shutil
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.conv2_drop = nn.Dropout2d()
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
+
+    def forward(self, x):
+        x = F.relu(F.max_pool2d(self.conv1(x), 2))
+        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+        x = x.view(-1, 320)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+
+def train_one_epoch(model, data_loader, optimizer, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(data_loader):
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % log_interval == 0:
+            print(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                    epoch,
+                    batch_idx * len(data),
+                    len(data_loader) * len(data),
+                    100.0 * batch_idx / len(data_loader),
+                    loss.item(),
+                )
+            )
+
+
+def train(learning_rate):
+    import torch.distributed as dist
+    from torch.nn.parallel import DistributedDataParallel as DDP
+    from torch.utils.data.distributed import DistributedSampler
+
+    print("Running distributed training")
+    dist.init_process_group("gloo")
+
+    temp_dir = tempfile.mkdtemp()
+
+    train_dataset = datasets.MNIST(
+        temp_dir,
+        train=True,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+        ),
+    )
+
+    train_sampler = DistributedSampler(dataset=train_dataset)
+    data_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=batch_size, sampler=train_sampler
+    )
+
+    model = Net()
+    ddp_model = DDP(model)
+
+    optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=momentum)
+    for epoch in range(1, num_epochs + 1):
+        train_one_epoch(ddp_model, data_loader, optimizer, epoch)
+
+    dist.destroy_process_group()
+
+    shutil.rmtree(temp_dir)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("lr", nargs="?", help="learning_rate", default=0.001)
+    args = parser.parse_args()
+    print("learning rate chosen: ", float(args.lr))
+    train(float(args.lr))
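
For reference, a minimal sketch of how the new test support file is driven, mirroring the two added tests. It assumes TorchDistributor is imported from pyspark.ml.torch.distributor, as in the test module; the second argument to run() is forwarded to the file's positional "lr" argument.

from pyspark.ml.torch.distributor import TorchDistributor  # assumed import path

# Launch the MNIST example on two local CPU processes; "0.01" reaches
# test_pytorch_training_file.py as its argparse "lr" positional argument.
TorchDistributor(num_processes=2, local_mode=True, use_gpu=False).run(
    "python/test_support/test_pytorch_training_file.py", "0.01"
)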