Skip to content

Commit a829ace

Browse files
committed
Update pytorch runner tests.
1 parent 6b92889 commit a829ace

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

tests/unitary/with_extras/jobs/test_pytorch_ddp.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
import os
66
import sys
77
import unittest
8-
from unittest import mock
8+
from unittest import SkipTest, mock
9+
910
from ads.jobs import DataScienceJobRun
1011
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
1112
PyTorchDistributedRuntimeHandler as Handler,
1213
)
13-
from ads.jobs.templates import driver_utils as utils
1414
from ads.jobs.templates import driver_pytorch as driver
15+
from ads.jobs.templates import driver_utils as utils
1516

1617

1718
class PyTorchRunnerTest(unittest.TestCase):
@@ -49,6 +50,8 @@ def test_wait_for_host_ip(self):
4950
{"message": f"{driver.LOG_PREFIX_HOST_IP} {self.TEST_HOST_IP}"}
5051
]
5152
runner = self.init_torch_runner()
53+
if not runner.host_job_run:
54+
raise SkipTest("Test is skipped for DTv2.")
5255
self.assertEqual(runner.host_ip, None)
5356
runner.wait_for_host_ip_address()
5457
self.assertEqual(runner.host_ip, self.TEST_HOST_IP)
@@ -147,7 +150,11 @@ def test_touch_file(self, run_command):
147150
runner.touch_file("stop")
148151
commasnds = [call_args.args[0] for call_args in run_command.call_args_list]
149152
self.assertEqual(
150-
commasnds, ["ssh -v 10.0.0.2 'touch stop'", "ssh -v 10.0.0.3 'touch stop'"]
153+
commasnds,
154+
[
155+
"ssh -v -o PasswordAuthentication=no 10.0.0.2 'touch stop'",
156+
"ssh -v -o PasswordAuthentication=no 10.0.0.3 'touch stop'",
157+
],
151158
)
152159

153160

@@ -161,6 +168,8 @@ def init_runner(self):
161168
"ads.jobs.DataScienceJobRun.from_ocid"
162169
) as GetJobRun, mock.patch(
163170
"ads.jobs.templates.driver_utils.JobRunner.run_command"
171+
), mock.patch(
172+
"ads.jobs.templates.driver_pytorch.DeepSpeedRunner._print_host_key"
164173
):
165174
GetHostIP.return_value = self.TEST_IP
166175
GetJobRun.return_value = DataScienceJobRun(id="ocid.abcdefghijk")
@@ -186,7 +195,7 @@ def test_run(self, time_cmd, run_command, run_deepspeed):
186195
self.assertTrue(
187196
time_cmd.call_args.kwargs["cmd"].endswith(
188197
"libhostname.so.1 OCI__HOSTNAME=10.0.0.1 "
189-
"accelerate launch --num_processes 2 --num_machines 2 --machine_rank 0 --main_process_port 29400 "
198+
"accelerate launch --num_processes 2 --num_machines 2 --machine_rank 0 --main_process_port 29400 --use_deepspeed "
190199
"train.py --data abc"
191200
),
192201
time_cmd.call_args.kwargs["cmd"],
@@ -206,6 +215,7 @@ def test_run(self, time_cmd, run_command, run_deepspeed):
206215
"10.0.0.1",
207216
"--main_process_port",
208217
"29400",
218+
"--use_deepspeed",
209219
"--deepspeed_hostfile=/home/datascience/hostfile",
210220
],
211221
)

0 commit comments

Comments
 (0)