1
- from contextlib import contextmanager
1
+ from __future__ import absolute_import , division , print_function
2
+
2
3
import logging
4
+ import math
3
5
import shlex
4
6
import socket
5
7
import subprocess
6
8
import sys
7
9
import warnings
10
+ from collections import OrderedDict
11
+ from contextlib import contextmanager
8
12
9
13
import dask
10
14
import docrep
11
15
from distributed import LocalCluster
12
16
from distributed .deploy import Cluster
13
- from distributed .utils import (get_ip_interface , ignoring , parse_bytes , tmpfile ,
14
- format_bytes )
17
+ from distributed .diagnostics .plugin import SchedulerPlugin
18
+ from distributed .utils import (
19
+ format_bytes , get_ip_interface , parse_bytes , tmpfile )
15
20
16
21
logger = logging .getLogger (__name__ )
17
22
docstrings = docrep .DocstringProcessor ()
28
33
""" .strip ()
29
34
30
35
36
+ def _job_id_from_worker_name (name ):
37
+ ''' utility to parse the job ID from the worker name
38
+
39
+ template: 'prefix--jobid--suffix'
40
+ '''
41
+ _ , job_id , _ = name .split ('--' )
42
+ return job_id
43
+
44
+
45
class JobQueuePlugin(SchedulerPlugin):
    """Scheduler plugin that tracks the lifecycle of queue jobs.

    Jobs move through three ordered dicts as their workers come and go:
    ``pending_jobs`` (submitted, no worker seen yet) -> ``running_jobs``
    (at least one active worker) -> ``finished_jobs`` (last worker gone).
    Each dict maps job_id -> {worker name: WorkerState}.
    """

    def __init__(self):
        self.pending_jobs = OrderedDict()
        self.running_jobs = OrderedDict()
        self.finished_jobs = OrderedDict()
        # worker address -> (worker name, job_id), so remove_worker can
        # find the job after the scheduler has dropped the WorkerState
        self.all_workers = {}

    def add_worker(self, scheduler, worker=None, name=None, **kwargs):
        ''' Run when a new worker enters the cluster'''
        logger.debug("adding worker %s", worker)
        w = scheduler.workers[worker]
        job_id = _job_id_from_worker_name(w.name)
        logger.debug("job id for new worker: %s", job_id)
        self.all_workers[worker] = (w.name, job_id)

        # if this is the first worker for this job, move job to running.
        # Default to an empty entry so a job this plugin never saw as
        # pending (e.g. submitted before the plugin was added) does not
        # raise KeyError inside the scheduler.
        if job_id not in self.running_jobs:
            logger.debug("this is a new job")
            self.running_jobs[job_id] = self.pending_jobs.pop(
                job_id, OrderedDict())

        # add worker to dict of workers in this job
        self.running_jobs[job_id][w.name] = w

    def remove_worker(self, scheduler=None, worker=None, **kwargs):
        ''' Run when a worker leaves the cluster'''
        logger.debug("removing worker %s", worker)
        name, job_id = self.all_workers[worker]
        # NOTE: original message concatenated to "...andjob_id" — space added
        logger.debug("removing worker name (%s) and job_id (%s)",
                     name, job_id)

        # remove worker from this job
        del self.running_jobs[job_id][name]

        # once there are no more workers, move this job to finished_jobs
        if not self.running_jobs[job_id]:
            logger.debug("that was the last worker for job %s", job_id)
            self.finished_jobs[job_id] = self.running_jobs.pop(job_id)
82
+
83
+
31
84
@docstrings .get_sectionsf ('JobQueueCluster' )
32
85
class JobQueueCluster (Cluster ):
33
86
""" Base class to launch Dask Clusters for Job queues
@@ -87,6 +140,8 @@ class JobQueueCluster(Cluster):
87
140
    # Shell command used to submit a job script; subclasses must override.
    submit_command = None
    # Shell command used to cancel submitted jobs; subclasses must override.
    cancel_command = None
    # Short identifier of the queueing system (set by subclasses).
    scheduler_name = ''
    # Group workers by their job ID when adapting, so Adaptive treats a
    # whole queue job (all of its worker processes) as one scaling unit.
    _adaptive_options = {
        'worker_key': lambda ws: _job_id_from_worker_name(ws.name)}
90
145
91
146
def __init__ (self ,
92
147
name = None ,
@@ -155,15 +210,17 @@ def __init__(self,
155
210
156
211
self .local_cluster = LocalCluster (n_workers = 0 , ip = host , ** kwargs )
157
212
158
- # Keep information on process, cores, and memory, for use in subclasses
159
- self . worker_memory = parse_bytes ( memory )
160
-
213
+ # Keep information on process, threads and memory, for use in
214
+ # subclasses
215
+ self . worker_memory = parse_bytes ( memory ) if memory is not None else None
161
216
self .worker_processes = processes
162
217
self .worker_cores = cores
163
218
self .name = name
164
219
165
- self .jobs = dict ()
166
- self .n = 0
220
+ # plugin for tracking job status
221
+ self ._scheduler_plugin = JobQueuePlugin ()
222
+ self .local_cluster .scheduler .add_plugin (self ._scheduler_plugin )
223
+
167
224
self ._adaptive = None
168
225
169
226
self ._env_header = '\n ' .join (env_extra )
@@ -179,47 +236,60 @@ def __init__(self,
179
236
mem = format_bytes (self .worker_memory / self .worker_processes )
180
237
mem = mem .replace (' ' , '' )
181
238
self ._command_template += " --memory-limit %s" % mem
239
+ self ._command_template += " --name %s--${JOB_ID}--" % name
182
240
183
- if name is not None :
184
- self ._command_template += " --name %s" % name
185
- self ._command_template += "-%(n)d" # Keep %(n) to be replaced later
186
241
if death_timeout is not None :
187
242
self ._command_template += " --death-timeout %s" % death_timeout
188
243
if local_directory is not None :
189
244
self ._command_template += " --local-directory %s" % local_directory
190
245
if extra is not None :
191
246
self ._command_template += extra
192
247
248
    @property
    def pending_jobs(self):
        """ Jobs submitted to the queue that have not yet started a worker """
        return self._scheduler_plugin.pending_jobs
252
+
253
    @property
    def running_jobs(self):
        """ Jobs with currently active workers """
        return self._scheduler_plugin.running_jobs
257
+
258
    @property
    def finished_jobs(self):
        """ Jobs whose workers have all left the cluster """
        return self._scheduler_plugin.finished_jobs
262
+
193
263
    @property
    def worker_threads(self):
        """ Threads per worker process: cores divided evenly across processes """
        return int(self.worker_cores / self.worker_processes)
196
266
197
267
def job_script (self ):
198
268
""" Construct a job submission script """
199
- self .n += 1
200
- template = self ._command_template % {'n' : self .n }
201
- return self ._script_template % {'job_header' : self .job_header ,
202
- 'env_header' : self ._env_header ,
203
- 'worker_command' : template }
269
+ pieces = {'job_header' : self .job_header ,
270
+ 'env_header' : self ._env_header ,
271
+ 'worker_command' : self ._command_template }
272
+ return self ._script_template % pieces
204
273
205
274
@contextmanager
206
275
def job_file (self ):
207
276
""" Write job submission script to temporary file """
208
277
with tmpfile (extension = 'sh' ) as fn :
209
278
with open (fn , 'w' ) as f :
279
+ logger .debug ("writing job script: \n %s" % self .job_script ())
210
280
f .write (self .job_script ())
211
281
yield fn
212
282
213
283
def start_workers (self , n = 1 ):
214
284
""" Start workers and point them to our local scheduler """
215
- workers = []
216
- for _ in range (n ):
285
+ logger .debug ('starting %s workers' % n )
286
+ num_jobs = math .ceil (n / self .worker_processes )
287
+ for _ in range (num_jobs ):
217
288
with self .job_file () as fn :
218
289
out = self ._call (shlex .split (self .submit_command ) + [fn ])
219
290
job = self ._job_id_from_submit_output (out .decode ())
220
- self .jobs [self .n ] = job
221
- workers .append (self .n )
222
- return workers
291
+ logger .debug ("started job: %s" % job )
292
+ self .pending_jobs [job ] = {}
223
293
224
294
@property
225
295
def scheduler (self ):
@@ -248,12 +318,12 @@ def _calls(self, cmds):
248
318
Also logs any stderr information
249
319
"""
250
320
logger .debug ("Submitting the following calls to command line" )
321
+ procs = []
251
322
for cmd in cmds :
252
323
logger .debug (' ' .join (cmd ))
253
- procs = [subprocess .Popen (cmd ,
254
- stdout = subprocess .PIPE ,
255
- stderr = subprocess .PIPE )
256
- for cmd in cmds ]
324
+ procs .append (subprocess .Popen (cmd ,
325
+ stdout = subprocess .PIPE ,
326
+ stderr = subprocess .PIPE ))
257
327
258
328
result = []
259
329
for proc in procs :
@@ -269,33 +339,60 @@ def _call(self, cmd):
269
339
270
340
def stop_workers (self , workers ):
271
341
""" Stop a list of workers"""
342
+ logger .debug ("Stopping workers: %s" % workers )
272
343
if not workers :
273
344
return
274
- workers = list (map (int , workers ))
275
- jobs = [self .jobs [w ] for w in workers ]
276
- self ._call ([self .cancel_command ] + list (jobs ))
345
+ jobs = self ._stop_pending_jobs () # stop pending jobs too
277
346
for w in workers :
278
- with ignoring (KeyError ):
279
- del self .jobs [w ]
347
+ if isinstance (w , dict ):
348
+ jobs .append (_job_id_from_worker_name (w ['name' ]))
349
+ else :
350
+ jobs .append (_job_id_from_worker_name (w .name ))
351
+ self .stop_jobs (set (jobs ))
352
+
353
+ def stop_jobs (self , jobs ):
354
+ """ Stop a list of jobs"""
355
+ logger .debug ("Stopping jobs: %s" % jobs )
356
+ if jobs :
357
+ jobs = list (jobs )
358
+ self ._call ([self .cancel_command ] + list (set (jobs )))
280
359
281
360
def scale_up (self , n , ** kwargs ):
282
361
""" Brings total worker count up to ``n`` """
283
- return self .start_workers (n - len (self .jobs ))
362
+ logger .debug ("Scaling up to %d workers." % n )
363
+ active_and_pending = sum ([len (j ) for j in self .running_jobs .values ()])
364
+ active_and_pending += self .worker_processes * len (self .pending_jobs )
365
+ logger .debug ("Found %d active/pending workers." % active_and_pending )
366
+ self .start_workers (n - active_and_pending )
284
367
285
368
def scale_down (self , workers ):
286
369
''' Close the workers with the given addresses '''
287
- if isinstance (workers , dict ):
288
- names = {v ['name' ] for v in workers .values ()}
289
- job_ids = {name .split ('-' )[- 2 ] for name in names }
290
- self .stop_workers (job_ids )
370
+ logger .debug ("Scaling down. Workers: %s" % workers )
371
+ worker_states = []
372
+ for w in workers :
373
+ try :
374
+ # Get the actual WorkerState
375
+ worker_states .append (self .scheduler .workers [w ])
376
+ except KeyError :
377
+ logger .debug ('worker %s is already gone' % w )
378
+ self .stop_workers (worker_states )
291
379
292
380
    def __enter__(self):
        # Context-manager entry; all teardown happens in __exit__.
        return self
294
382
295
383
def __exit__ (self , type , value , traceback ):
296
- self .stop_workers (self .jobs )
384
+ jobs = self ._stop_pending_jobs ()
385
+ jobs += list (self .running_jobs .keys ())
386
+ self .stop_jobs (set (jobs ))
297
387
self .local_cluster .__exit__ (type , value , traceback )
298
388
389
+ def _stop_pending_jobs (self ):
390
+ jobs = list (self .pending_jobs .keys ())
391
+ logger .debug ("Stopping pending jobs %s" % jobs )
392
+ for job_id in jobs :
393
+ del self .pending_jobs [job_id ]
394
+ return jobs
395
+
299
396
def _job_id_from_submit_output (self , out ):
300
397
raise NotImplementedError ('_job_id_from_submit_output must be '
301
398
'implemented when JobQueueCluster is '
0 commit comments