 import json
 import logging
 import os
+import tempfile
 import time
 from dataclasses import dataclass, field
 from datetime import datetime
-from shutil import copy2, copytree, rmtree
-from tempfile import mkdtemp
-from typing import Any, cast, Dict, List, Mapping, Optional, Set, Type  # noqa
+from shutil import copy2, rmtree
+from typing import Any, cast, Dict, Iterable, List, Mapping, Optional, Set, Type  # noqa

 from torchx.schedulers.api import (
     AppDryRunInfo,
     AppState,
     DescribeAppResponse,
+    filter_regex,
     Scheduler,
     split_lines,
     Stream,
 )
 from torchx.schedulers.ids import make_unique
 from torchx.schedulers.ray.ray_common import RayActor, TORCHX_RANK0_HOST
 from torchx.specs import AppDef, macros, NONE, ReplicaStatus, Role, RoleStatus, runopts
+from torchx.workspace.dir_workspace import TmpDirWorkspace
 from typing_extensions import TypedDict


@@ -92,7 +94,7 @@ class RayJob:
         dashboard_address:
             The existing dashboard IP address to connect to
         working_dir:
-            The working directory to copy to the cluster
+            The working directory to copy to the cluster
         requirements:
             The libraries to install on the cluster per requirements.txt
         actors:
@@ -102,15 +104,24 @@ class RayJob:
     """

     app_id: str
+    working_dir: str
     cluster_config_file: Optional[str] = None
     cluster_name: Optional[str] = None
     dashboard_address: Optional[str] = None
-    working_dir: Optional[str] = None
     requirements: Optional[str] = None
     actors: List[RayActor] = field(default_factory=list)

-class RayScheduler(Scheduler[RayOpts]):
+class RayScheduler(Scheduler[RayOpts], TmpDirWorkspace):
     """
+    RayScheduler is a TorchX scheduling interface to Ray. The job def's
+    workers will be launched as Ray actors.
+
+    The job environment is specified by the TorchX workspace. Any files in
+    the workspace will be present in the Ray job unless specified in
+    ``.torchxignore``. Python dependencies will be read from the
+    ``requirements.txt`` file located at the root of the workspace unless
+    it's overridden via ``-c ...,requirements=foo/requirements.txt``.
+
     **Config Options**

     .. runopts::
@@ -122,12 +133,15 @@ class RayScheduler(Scheduler[RayOpts]):
         type: scheduler
         features:
             cancel: true
-            logs: true
+            logs: |
+                Partial support. Ray only supports a single log stream so
+                only a dummy "ray/0" combined log role is supported.
+                Tailing and time seeking are not supported.
             distributed: true
             describe: |
                 Partial support. RayScheduler will return job status but
                 does not provide the complete original AppSpec.
-            workspaces: false
+            workspaces: true
             mounts: false

     """
@@ -156,11 +170,6 @@ def run_opts(self) -> runopts:
             default="127.0.0.1:8265",
             help="Use ray status to get the dashboard address you will submit jobs against",
         )
-        opts.add(
-            "working_dir",
-            type_=str,
-            help="Copy the the working directory containing the Python scripts to the cluster.",
-        )
         opts.add("requirements", type_=str, help="Path to requirements.txt")
         return opts

@@ -169,7 +178,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[RayJob]) -> str:

         # Create serialized actors for ray_driver.py
         actors = cfg.actors
-        dirpath = mkdtemp()
+        dirpath = cfg.working_dir
         serialize(actors, dirpath)

         job_submission_addr: str = ""
@@ -189,41 +198,46 @@ def schedule(self, dryrun_info: AppDryRunInfo[RayJob]) -> str:
             f"http://{job_submission_addr}"
         )

-        # 1. Copy working directory
-        if cfg.working_dir:
-            copytree(cfg.working_dir, dirpath, dirs_exist_ok=True)
-
-        # 2. Copy Ray driver utilities
+        # 1. Copy Ray driver utilities
         current_directory = os.path.dirname(os.path.abspath(__file__))
         copy2(os.path.join(current_directory, "ray", "ray_driver.py"), dirpath)
         copy2(os.path.join(current_directory, "ray", "ray_common.py"), dirpath)

-        # 3. Parse requirements.txt
-        reqs: List[str] = []
-        if cfg.requirements:  # pragma: no cover
-            with open(cfg.requirements) as f:
-                for line in f:
-                    reqs.append(line.strip())
+        runtime_env = {"working_dir": dirpath}
+        if cfg.requirements:
+            runtime_env["pip"] = cfg.requirements

-        # 4. Submit Job via the Ray Job Submission API
+        # 2. Submit Job via the Ray Job Submission API
         try:
             job_id: str = client.submit_job(
                 job_id=cfg.app_id,
                 # we will pack, hash, zip, upload, register working_dir in GCS of ray cluster
                 # and use it to configure your job execution.
                 entrypoint="python3 ray_driver.py",
-                runtime_env={"working_dir": dirpath, "pip": reqs},
+                runtime_env=runtime_env,
             )

         finally:
-            rmtree(dirpath)
+            if dirpath.startswith(tempfile.gettempdir()):
+                rmtree(dirpath)

         # Encode job submission client in job_id
         return f"{job_submission_addr}-{job_id}"

     def _submit_dryrun(self, app: AppDef, cfg: RayOpts) -> AppDryRunInfo[RayJob]:
         app_id = make_unique(app.name)
-        requirements = cfg.get("requirements")
+
+        working_dir = app.roles[0].image
+        if not os.path.exists(working_dir):
+            raise RuntimeError(
+                f"Role image must be a valid directory, got: {working_dir}"
+            )
+
+        requirements: Optional[str] = cfg.get("requirements")
+        if requirements is None:
+            workspace_reqs = os.path.join(working_dir, "requirements.txt")
+            if os.path.exists(workspace_reqs):
+                requirements = workspace_reqs

         cluster_cfg = cfg.get("cluster_config_file")
         if cluster_cfg:
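The ``runtime_env`` built above maps directly onto Ray's Job Submission API. A minimal sketch of the equivalent raw call, assuming a local dashboard and hypothetical paths (only the ``ray_driver.py`` entrypoint comes from the diff itself):

from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")
job_id = client.submit_job(
    entrypoint="python3 ray_driver.py",
    runtime_env={
        # the staged working dir: serialized actors + ray_driver.py + ray_common.py
        "working_dir": "/home/me/project",
        # only set when a requirements file was resolved
        "pip": "/home/me/project/requirements.txt",
    },
)
print(client.get_job_status(job_id))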
@@ -234,8 +248,9 @@ def _submit_dryrun(self, app: AppDef, cfg: RayOpts) -> AppDryRunInfo[RayJob]:

             job: RayJob = RayJob(
                 app_id,
-                cluster_cfg,
+                cluster_config_file=cluster_cfg,
                 requirements=requirements,
+                working_dir=working_dir,
             )

         else:  # pragma: no cover
@@ -244,9 +259,9 @@ def _submit_dryrun(self, app: AppDef, cfg: RayOpts) -> AppDryRunInfo[RayJob]:
                 app_id=app_id,
                 dashboard_address=dashboard_address,
                 requirements=requirements,
+                working_dir=working_dir,
             )
         job.cluster_name = cfg.get("cluster_name")
-        job.working_dir = cfg.get("working_dir")

         for role in app.roles:
             for replica_id in range(role.num_replicas):
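Because ``working_dir`` is now a required, non-defaulted field, any direct construction of ``RayJob`` (for example in tests) has to pass it explicitly; a small sketch with made-up values:

job = RayJob(
    app_id="trainer-abc123",
    working_dir="/home/me/project",
    dashboard_address="127.0.0.1:8265",
    requirements="/home/me/project/requirements.txt",
)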
@@ -298,12 +313,10 @@ def wait_until_finish(self, app_id: str, timeout: int = 30) -> None:
         with a given timeout. This is intended for testing. Programmatic
         usage should use the runner wait method instead.
         """
-        addr, _, app_id = app_id.partition("-")

-        client = JobSubmissionClient(f"http://{addr}")
         start = time.time()
         while time.time() - start <= timeout:
-            status_info = client.get_job_status(app_id)
+            status_info = self._get_job_status(app_id)
             status = status_info
             if status in {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}:
                 break
@@ -314,12 +327,18 @@ def _cancel_existing(self, app_id: str) -> None:  # pragma: no cover
         client = JobSubmissionClient(f"http://{addr}")
         client.stop_job(app_id)

-    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+    def _get_job_status(self, app_id: str) -> JobStatus:
         addr, _, app_id = app_id.partition("-")
         client = JobSubmissionClient(f"http://{addr}")
-        job_status_info = client.get_job_status(app_id)
+        status = client.get_job_status(app_id)
+        if isinstance(status, str):
+            return cast(JobStatus, status)
+        return status.status
+
+    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        job_status_info = self._get_job_status(app_id)
         state = _ray_status_to_torchx_appstate[job_status_info]
-        roles = [Role(name="ray", num_replicas=-1, image="<N/A>")]
+        roles = [Role(name="ray", num_replicas=1, image="<N/A>")]

         # get ip_address and put it in hostname

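Both ``_get_job_status`` and ``_cancel_existing`` (and ``log_iter`` below) decode the ``{submission_addr}-{job_id}`` handle that ``schedule`` returns. A quick worked example of the ``partition`` split, with made-up values:

app_id = "127.0.0.1:8265-trainer-abc123"
addr, _, job_id = app_id.partition("-")
assert addr == "127.0.0.1:8265"    # everything before the first "-"
assert job_id == "trainer-abc123"  # the remainder, later dashes preserved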
@@ -354,12 +373,15 @@ def log_iter(
         until: Optional[datetime] = None,
         should_tail: bool = False,
         streams: Optional[Stream] = None,
-    ) -> List[str]:
-        # TODO: support regex, tailing, streams etc..
+    ) -> Iterable[str]:
+        # TODO: support tailing, streams etc..
         addr, _, app_id = app_id.partition("-")
         client: JobSubmissionClient = JobSubmissionClient(f"http://{addr}")
         logs: str = client.get_job_logs(app_id)
-        return split_lines(logs)
+        iterator = split_lines(logs)
+        if regex:
+            return filter_regex(regex, iterator)
+        return iterator

 def create_scheduler(session_name: str, **kwargs: Any) -> RayScheduler:
     if not has_ray():  # pragma: no cover
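Finally, a sketch of how the new regex filtering in ``log_iter`` might be exercised; the scheduler construction and app id are hypothetical, and per the capability notes above only the combined ``ray/0`` log role exists:

scheduler = create_scheduler("default")
for line in scheduler.log_iter(
    app_id="127.0.0.1:8265-trainer-abc123",
    role_name="ray",
    k=0,
    regex=r"error|warning",
):
    print(line)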