Commit 78fe9d4

Parallel execution uses thread pool to avoid leaking resources. (#1367)

Authored by tetron and mr-c
Co-authored-by: Michael R. Crusoe <[email protected]>
1 parent 6b3e50b · commit 78fe9d4

File tree

cwltool/command_line_tool.py
cwltool/executors.py
cwltool/task_queue.py

3 files changed: +171 −31 lines changed

cwltool/command_line_tool.py

Lines changed: 1 addition & 0 deletions
@@ -1145,6 +1145,7 @@ def collect_output(
         r = []  # type: List[CWLOutputType]
         empty_and_optional = False
         debug = _logger.isEnabledFor(logging.DEBUG)
+        result: Optional[CWLOutputType] = None
         if "outputBinding" in schema:
             binding = cast(
                 MutableMapping[str, Union[bool, str, List[str]]],

cwltool/executors.py

Lines changed: 44 additions & 31 deletions
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 """Single and multi-threaded executors."""
 import datetime
+import functools
 import logging
 import math
 import os
@@ -34,6 +35,8 @@
 from .utils import CWLObjectType, JobsType
 from .workflow import Workflow
 from .workflow_job import WorkflowJob, WorkflowJobStep
+from .task_queue import TaskQueue
+

 TMPDIR_LOCK = Lock()

@@ -277,7 +280,6 @@ class MultithreadedJobExecutor(JobExecutor):
     def __init__(self) -> None:
         """Initialize."""
         super(MultithreadedJobExecutor, self).__init__()
-        self.threads = set()  # type: Set[threading.Thread]
         self.exceptions = []  # type: List[WorkflowException]
         self.pending_jobs = []  # type: List[JobsType]
         self.pending_jobs_lock = threading.Lock()
@@ -339,7 +341,6 @@ def _runner(self, job, runtime_context, TMPDIR_LOCK):
         finally:
             if runtime_context.workflow_eval_lock:
                 with runtime_context.workflow_eval_lock:
-                    self.threads.remove(threading.current_thread())
                     if isinstance(job, JobBase):
                         ram = job.builder.resources["ram"]
                         if not isinstance(ram, str):
@@ -362,6 +363,10 @@ def run_job(
         with self.pending_jobs_lock:
             n = 0
             while (n + 1) <= len(self.pending_jobs):
+                # Simple greedy resource allocation strategy. Go
+                # through pending jobs in the order they were
+                # generated and add them to the queue only if there
+                # are resources available.
                 job = self.pending_jobs[n]
                 if isinstance(job, JobBase):
                     ram = job.builder.resources["ram"]
@@ -403,26 +408,24 @@ def run_job(
                         n += 1
                         continue

-                thread = threading.Thread(
-                    target=self._runner, args=(job, runtime_context, TMPDIR_LOCK)
-                )
-                thread.daemon = True
-                self.threads.add(thread)
                 if isinstance(job, JobBase):
                     ram = job.builder.resources["ram"]
                     if not isinstance(ram, str):
                         self.allocated_ram += ram
                     cores = job.builder.resources["cores"]
                     if not isinstance(cores, str):
                         self.allocated_cores += cores
-                thread.start()
+                self.taskqueue.add(
+                    functools.partial(self._runner, job, runtime_context, TMPDIR_LOCK),
+                    runtime_context.workflow_eval_lock,
+                )
                 self.pending_jobs.remove(job)

     def wait_for_next_completion(self, runtime_context):
         # type: (RuntimeContext) -> None
         """Wait for jobs to finish."""
         if runtime_context.workflow_eval_lock is not None:
-            runtime_context.workflow_eval_lock.wait()
+            runtime_context.workflow_eval_lock.wait(timeout=3)
         if self.exceptions:
             raise self.exceptions[0]

@@ -434,36 +437,46 @@ def run_jobs(
         runtime_context: RuntimeContext,
     ) -> None:

-        jobiter = process.job(job_order_object, self.output_callback, runtime_context)
+        self.taskqueue = TaskQueue(
+            threading.Lock(), psutil.cpu_count()
+        )  # type: TaskQueue
+        try:

-        if runtime_context.workflow_eval_lock is None:
-            raise WorkflowException(
-                "runtimeContext.workflow_eval_lock must not be None"
+            jobiter = process.job(
+                job_order_object, self.output_callback, runtime_context
             )

-        runtime_context.workflow_eval_lock.acquire()
-        for job in jobiter:
-            if job is not None:
-                if isinstance(job, JobBase):
-                    job.builder = runtime_context.builder or job.builder
-                    if job.outdir is not None:
-                        self.output_dirs.add(job.outdir)
+            if runtime_context.workflow_eval_lock is None:
+                raise WorkflowException(
+                    "runtimeContext.workflow_eval_lock must not be None"
+                )

-                self.run_job(job, runtime_context)
+            runtime_context.workflow_eval_lock.acquire()
+            for job in jobiter:
+                if job is not None:
+                    if isinstance(job, JobBase):
+                        job.builder = runtime_context.builder or job.builder
+                        if job.outdir is not None:
+                            self.output_dirs.add(job.outdir)

-            if job is None:
-                if self.threads:
-                    self.wait_for_next_completion(runtime_context)
-                else:
-                    logger.error("Workflow cannot make any more progress.")
-                    break
+                    self.run_job(job, runtime_context)
+
+                if job is None:
+                    if self.taskqueue.in_flight > 0:
+                        self.wait_for_next_completion(runtime_context)
+                    else:
+                        logger.error("Workflow cannot make any more progress.")
+                        break

-        self.run_job(None, runtime_context)
-        while self.threads:
-            self.wait_for_next_completion(runtime_context)
             self.run_job(None, runtime_context)
+            while self.taskqueue.in_flight > 0:
+                self.wait_for_next_completion(runtime_context)
+            self.run_job(None, runtime_context)

-        runtime_context.workflow_eval_lock.release()
+            runtime_context.workflow_eval_lock.release()
+        finally:
+            self.taskqueue.drain()
+            self.taskqueue.join()


 class NoopJobExecutor(JobExecutor):
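The scheduling change above can be hard to follow in diff form: run_job() no longer spawns a daemon thread per job but hands a functools.partial to the shared TaskQueue, passing workflow_eval_lock as the "unlock" argument so the lock is released while the executor blocks on a full queue. The standalone sketch below is not cwltool code; eval_lock, runner, and worker are illustrative names. It shows the same hand-off pattern with the standard queue and threading modules.

import functools
import queue
import threading

# Stands in for runtime_context.workflow_eval_lock (a Condition in cwltool).
eval_lock = threading.Condition(threading.RLock())
# Bounded queue: at most 2 tasks may wait, mirroring TaskQueue's maxsize=thread_count.
task_queue: "queue.Queue[object]" = queue.Queue(maxsize=2)

def runner(job_name: str) -> None:
    # Like MultithreadedJobExecutor._runner: finish the job, then notify the
    # scheduler under the shared lock so wait_for_next_completion() wakes up.
    with eval_lock:
        print(f"finished {job_name}")
        eval_lock.notify_all()

def worker() -> None:
    # Like TaskQueue._task_queue_func: run tasks until a None sentinel arrives.
    while True:
        task = task_queue.get()
        if task is None:
            return
        task()

workers = [threading.Thread(target=worker) for _ in range(2)]
for t in workers:
    t.start()

with eval_lock:  # the scheduler holds the lock while generating jobs
    for i in range(5):
        task = functools.partial(runner, f"job{i}")
        eval_lock.release()        # release so running jobs can notify ...
        try:
            task_queue.put(task)   # ... while we block on a full queue
        finally:
            eval_lock.acquire()    # reacquire before scheduling the next job

for _ in workers:
    task_queue.put(None)           # one sentinel per worker thread
for t in workers:
    t.join()

Because every worker thread belongs to the pool and exits on its sentinel, nothing leaks even if the scheduler stops early, which is the point of the commit.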

cwltool/task_queue.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# Copyright (C) The Arvados Authors. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+"""TaskQueue."""
+
+import queue
+import threading
+
+from typing import Callable, Optional
+
+from .loghandler import _logger
+
+
+class TaskQueue(object):
+    """A TaskQueue class.
+
+    Uses a first-in, first-out queue of tasks executed on a fixed number of
+    threads.
+
+    New tasks enter the queue and are started in the order received,
+    as worker threads become available.
+
+    If thread_count == 0 then tasks will be synchronously executed
+    when add() is called (this makes the actual task queue behavior a
+    no-op, but may be a useful configuration knob).
+
+    The thread_count is also used as the maximum size of the queue.
+
+    The threads are created during TaskQueue initialization.  Call
+    join() when you're done with the TaskQueue and want the threads to
+    stop.
+
+
+    Attributes
+    ----------
+    in_flight
+        the number of tasks in the queue
+
+    """
+
+    def __init__(self, lock: threading.Lock, thread_count: int):
+        """Create a new task queue using the specified lock and number of threads."""
+        self.thread_count = thread_count
+        self.task_queue: queue.Queue[Optional[Callable[[], None]]] = queue.Queue(
+            maxsize=self.thread_count
+        )
+        self.task_queue_threads = []
+        self.lock = lock
+        self.in_flight = 0
+        self.error: Optional[BaseException] = None
+
+        for _r in range(0, self.thread_count):
+            t = threading.Thread(target=self._task_queue_func)
+            self.task_queue_threads.append(t)
+            t.start()
+
+    def _task_queue_func(self) -> None:
+        while True:
+            task = self.task_queue.get()
+            if task is None:
+                return
+            try:
+                task()
+            except BaseException as e:
+                _logger.exception("Unhandled exception running task")
+                self.error = e
+            finally:
+                with self.lock:
+                    self.in_flight -= 1
+
+    def add(
+        self,
+        task: Callable[[], None],
+        unlock: Optional[threading.Condition] = None,
+        check_done: Optional[threading.Event] = None,
+    ) -> None:
+        """
+        Add your task to the queue.
+
+        The optional unlock will be released prior to attempting to add the
+        task to the queue.
+
+        If the optional "check_done" threading.Event's flag is set, then we
+        will skip adding this task to the queue.
+
+        If the TaskQueue was created with thread_count == 0 then your task will
+        be synchronously executed.
+
+        """
+        if self.thread_count == 0:
+            task()
+            return
+
+        with self.lock:
+            self.in_flight += 1
+
+        while True:
+            try:
+                if unlock is not None:
+                    unlock.release()
+                if check_done is not None and check_done.is_set():
+                    with self.lock:
+                        self.in_flight -= 1
+                    return
+                self.task_queue.put(task, block=True, timeout=3)
+                return
+            except queue.Full:
+                pass
+            finally:
+                if unlock is not None:
+                    unlock.acquire()
+
+    def drain(self) -> None:
+        """Drain the queue."""
+        try:
+            while not self.task_queue.empty():
+                self.task_queue.get(True, 0.1)
+        except queue.Empty:
+            pass
+
+    def join(self) -> None:
+        """Wait for all threads to complete."""
+        for _t in self.task_queue_threads:
+            self.task_queue.put(None)
+        for t in self.task_queue_threads:
+            t.join()
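For reference, a minimal usage sketch of the new class, mirroring how MultithreadedJobExecutor drives it. The example tasks and the polling loop are illustrative and not part of the commit; the TaskQueue API is as defined above.

import functools
import threading
import time

from cwltool.task_queue import TaskQueue

lock = threading.Lock()
tq = TaskQueue(lock, 2)   # starts two worker threads immediately

try:
    for i in range(4):
        # Each task is a zero-argument callable, like the executor's
        # functools.partial(self._runner, job, runtime_context, TMPDIR_LOCK).
        tq.add(functools.partial(print, f"task {i} done"))

    # The executor waits via wait_for_next_completion() until nothing is in
    # flight; a plain polling loop is enough for this sketch.
    while tq.in_flight > 0:
        time.sleep(0.1)
finally:
    tq.drain()   # discard anything still sitting in the queue
    tq.join()    # enqueue one None sentinel per worker and wait for the threads

if tq.error is not None:
    # The first unhandled exception raised by a task, if any, is stashed here.
    raise tq.error

Draining and joining in a finally block is the same shutdown pattern run_jobs() uses, so the worker threads always exit rather than being leaked as daemon threads.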
