Skip to content

Commit bff8a26

Browse files
authored
Fix scatter over workflows by advancing iteration over all steps (#187)
* Fix scatter over workflows by advancing iteration of all steps of a scatter.
1 parent 2b3d2ec commit bff8a26

File tree

1 file changed

+46
-25
lines changed

1 file changed

+46
-25
lines changed

cwltool/workflow.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,13 @@ def valueFromFunc(k, v): # type: (Any, Any) -> Any
305305
# https://github.com/python/mypy/issues/797
306306
**kwargs)
307307
elif method == "flat_crossproduct":
308-
jobs = flat_crossproduct_scatter(step, inputobj,
309-
scatter,
310-
cast(Callable[[Any], Any],
308+
jobs = cast(Generator,
309+
flat_crossproduct_scatter(step, inputobj,
310+
scatter,
311+
cast(Callable[[Any], Any],
311312
# known bug in mypy
312313
# https://github.com/python/mypy/issues/797
313-
callback), 0, **kwargs)
314+
callback), 0, **kwargs))
314315
else:
315316
_logger.debug(u"[job %s] job input %s", step.name, json.dumps(inputobj, indent=4))
316317
inputobj = postScatterEval(inputobj)
@@ -332,7 +333,7 @@ def run(self, **kwargs):
332333
_logger.debug(u"[%s] workflow starting", self.name)
333334

334335
def job(self, joborder, output_callback, **kwargs):
335-
# type: (Dict[Text, Any], Callable[[Any, Any], Any], **Any) -> Generator[WorkflowJob, None, None]
336+
# type: (Dict[Text, Any], Callable[[Any, Any], Any], **Any) -> Generator
336337
self.state = {}
337338
self.processStatus = "success"
338339

@@ -405,7 +406,7 @@ def __init__(self, toolpath_object, **kwargs):
405406
# TODO: statically validate data links instead of doing it at runtime.
406407

407408
def job(self, joborder, output_callback, **kwargs):
408-
# type: (Dict[Text, Text], Callable[[Any, Any], Any], **Any) -> Generator[WorkflowJob, None, None]
409+
# type: (Dict[Text, Text], Callable[[Any, Any], Any], **Any) -> Generator
409410
builder = self._init_job(joborder, **kwargs)
410411
wj = WorkflowJob(self, **kwargs)
411412
yield wj
@@ -577,9 +578,25 @@ def setTotal(self, total): # type: (int) -> None
577578
if self.completed == self.total:
578579
self.output_callback(self.dest, self.processStatus)
579580

581+
def parallel_steps(steps, rc, kwargs):  # type: (List[Generator], ReceiveScatterOutput, Dict[str, Any]) -> Generator
    """Advance every step generator of a scatter in round-robin passes.

    Each pass visits the generators in ``steps`` in order and drains each
    one until it yields a falsy value or is exhausted, re-yielding every
    truthy item it produces.  Passes repeat while ``rc.completed`` is
    below ``rc.total``.

    :param steps: generators to interleave, one per scattered invocation.
    :param rc: object exposing ``completed``, ``total`` and
        ``processStatus``; it is only read here, never modified.
    :param kwargs: options mapping; only ``on_error`` is consulted
        (treated as ``"stop"`` when absent).
    :return: a generator yielding the truthy items produced by the steps,
        plus ``None`` after any pass in which no step produced an item
        while ``rc.completed`` is still below ``rc.total``.
    """
    def halted():
        # Evaluated on every check (not cached): rc.processStatus may change
        # while this generator is suspended at a yield.
        return kwargs.get("on_error", "stop") == "stop" and rc.processStatus != "success"

    while rc.completed < rc.total:
        made_progress = False
        for step in steps:
            if halted():
                break
            for j in step:
                # An item already pulled from the step is dropped, not
                # yielded, once the halt condition holds.
                if halted():
                    break
                if j:
                    made_progress = True
                    yield j
                else:
                    # Falsy item: this step has nothing to hand out right
                    # now, so move on to the next step.
                    break
        if not made_progress and rc.completed < rc.total:
            # Nothing was produced in this pass; hand control back to the
            # consumer instead of spinning silently.
            yield None
580597

581598
def dotproduct_scatter(process, joborder, scatter_keys, output_callback, **kwargs):
582-
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator[WorkflowJob, None, None]
599+
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator
583600
l = None
584601
for s in scatter_keys:
585602
if l is None:
@@ -593,21 +610,23 @@ def dotproduct_scatter(process, joborder, scatter_keys, output_callback, **kwarg
593610

594611
rc = ReceiveScatterOutput(output_callback, output)
595612

613+
steps = []
596614
for n in range(0, l):
597615
jo = copy.copy(joborder)
598616
for s in scatter_keys:
599617
jo[s] = joborder[s][n]
600618

601619
jo = kwargs["postScatterEval"](jo)
602620

603-
for j in process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs):
604-
yield j
621+
steps.append(process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs))
605622

606623
rc.setTotal(l)
607624

625+
return parallel_steps(steps, rc, kwargs)
626+
608627

609628
def nested_crossproduct_scatter(process, joborder, scatter_keys, output_callback, **kwargs):
610-
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator[WorkflowJob, None, None]
629+
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator
611630
scatter_key = scatter_keys[0]
612631
l = len(joborder[scatter_key])
613632
output = {} # type: Dict[Text,List[Text]]
@@ -616,25 +635,24 @@ def nested_crossproduct_scatter(process, joborder, scatter_keys, output_callback
616635

617636
rc = ReceiveScatterOutput(output_callback, output)
618637

638+
steps = []
619639
for n in range(0, l):
620640
jo = copy.copy(joborder)
621641
jo[scatter_key] = joborder[scatter_key][n]
622642

623643
if len(scatter_keys) == 1:
624644
jo = kwargs["postScatterEval"](jo)
625-
for j in process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs):
626-
yield j
645+
steps.append(process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs))
627646
else:
628-
for j in nested_crossproduct_scatter(process, jo,
647+
steps.append(nested_crossproduct_scatter(process, jo,
629648
scatter_keys[1:], cast( # known bug with mypy
630-
# https://github.com/python/mypy/issues/797
649+
# https://github.com/python/mypy/issues/797
631650
Callable[[Any], Any],
632-
functools.partial(rc.receive_scatter_output, n)),
633-
**kwargs):
634-
yield j
651+
functools.partial(rc.receive_scatter_output, n)), **kwargs))
635652

636653
rc.setTotal(l)
637654

655+
return parallel_steps(steps, rc, kwargs)
638656

639657
def crossproduct_size(joborder, scatter_keys):
640658
# type: (Dict[Text, Any], List[Text]) -> int
@@ -650,7 +668,7 @@ def crossproduct_size(joborder, scatter_keys):
650668
return sum
651669

652670
def flat_crossproduct_scatter(process, joborder, scatter_keys, output_callback, startindex, **kwargs):
653-
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Union[ReceiveScatterOutput,Callable[..., Any]], int, **Any) -> Generator[WorkflowJob, None, None]
671+
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Union[ReceiveScatterOutput,Callable[..., Any]], int, **Any) -> Union[List[Generator], Generator]
654672
scatter_key = scatter_keys[0]
655673
l = len(joborder[scatter_key])
656674
rc = None # type: ReceiveScatterOutput
@@ -665,20 +683,23 @@ def flat_crossproduct_scatter(process, joborder, scatter_keys, output_callback,
665683
else:
666684
raise Exception("Unhandled code path. Please report this.")
667685

686+
steps = []
668687
put = startindex
669688
for n in range(0, l):
670689
jo = copy.copy(joborder)
671690
jo[scatter_key] = joborder[scatter_key][n]
672691

673692
if len(scatter_keys) == 1:
674693
jo = kwargs["postScatterEval"](jo)
675-
for j in process.job(jo, functools.partial(rc.receive_scatter_output, put), **kwargs):
676-
yield j
694+
steps.append(process.job(jo, functools.partial(rc.receive_scatter_output, put), **kwargs))
677695
put += 1
678696
else:
679-
for j in flat_crossproduct_scatter(process, jo, scatter_keys[1:], rc, put, **kwargs):
680-
if j:
681-
put += 1
682-
yield j
697+
add = flat_crossproduct_scatter(process, jo, scatter_keys[1:], rc, put, **kwargs)
698+
put += len(cast(List[Generator], add))
699+
steps.extend(add)
683700

684-
rc.setTotal(put)
701+
if startindex == 0 and not isinstance(output_callback, ReceiveScatterOutput):
702+
rc.setTotal(put)
703+
return parallel_steps(steps, rc, kwargs)
704+
else:
705+
return steps

0 commit comments

Comments
 (0)