1
1
#!/usr/bin/env python
2
2
3
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
4
4
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
import os
6
6
import sys
7
7
import unittest
8
- from unittest import mock
8
+ from unittest import SkipTest , mock
9
+
9
10
from ads .jobs import DataScienceJobRun
10
11
from ads .jobs .builders .infrastructure .dsc_job_runtime import (
11
12
PyTorchDistributedRuntimeHandler as Handler ,
12
13
)
13
- from ads .jobs .templates import driver_utils as utils
14
14
from ads .jobs .templates import driver_pytorch as driver
15
+ from ads .jobs .templates import driver_utils as utils
15
16
16
17
17
18
class PyTorchRunnerTest (unittest .TestCase ):
@@ -49,6 +50,8 @@ def test_wait_for_host_ip(self):
49
50
{"message" : f"{ driver .LOG_PREFIX_HOST_IP } { self .TEST_HOST_IP } " }
50
51
]
51
52
runner = self .init_torch_runner ()
53
+ if not runner .host_job_run :
54
+ raise SkipTest ("Test is skipped for DTv2." )
52
55
self .assertEqual (runner .host_ip , None )
53
56
runner .wait_for_host_ip_address ()
54
57
self .assertEqual (runner .host_ip , self .TEST_HOST_IP )
@@ -147,7 +150,11 @@ def test_touch_file(self, run_command):
147
150
runner .touch_file ("stop" )
148
151
commasnds = [call_args .args [0 ] for call_args in run_command .call_args_list ]
149
152
self .assertEqual (
150
- commasnds , ["ssh -v 10.0.0.2 'touch stop'" , "ssh -v 10.0.0.3 'touch stop'" ]
153
+ commasnds ,
154
+ [
155
+ "ssh -v -o PasswordAuthentication=no 10.0.0.2 'touch stop'" ,
156
+ "ssh -v -o PasswordAuthentication=no 10.0.0.3 'touch stop'" ,
157
+ ],
151
158
)
152
159
153
160
@@ -161,6 +168,8 @@ def init_runner(self):
161
168
"ads.jobs.DataScienceJobRun.from_ocid"
162
169
) as GetJobRun , mock .patch (
163
170
"ads.jobs.templates.driver_utils.JobRunner.run_command"
171
+ ), mock .patch (
172
+ "ads.jobs.templates.driver_pytorch.DeepSpeedRunner._print_host_key"
164
173
):
165
174
GetHostIP .return_value = self .TEST_IP
166
175
GetJobRun .return_value = DataScienceJobRun (id = "ocid.abcdefghijk" )
@@ -186,7 +195,7 @@ def test_run(self, time_cmd, run_command, run_deepspeed):
186
195
self .assertTrue (
187
196
time_cmd .call_args .kwargs ["cmd" ].endswith (
188
197
"libhostname.so.1 OCI__HOSTNAME=10.0.0.1 "
189
- "accelerate launch --num_processes 2 --num_machines 2 --machine_rank 0 --main_process_port 29400 "
198
+ "accelerate launch --num_processes 2 --num_machines 2 --machine_rank 0 --main_process_port 29400 --use_deepspeed "
190
199
"train.py --data abc"
191
200
),
192
201
time_cmd .call_args .kwargs ["cmd" ],
@@ -206,6 +215,7 @@ def test_run(self, time_cmd, run_command, run_deepspeed):
206
215
"10.0.0.1" ,
207
216
"--main_process_port" ,
208
217
"29400" ,
218
+ "--use_deepspeed" ,
209
219
"--deepspeed_hostfile=/home/datascience/hostfile" ,
210
220
],
211
221
)
0 commit comments