@@ -305,12 +305,13 @@ def valueFromFunc(k, v): # type: (Any, Any) -> Any
305
305
# https://github.com/python/mypy/issues/797
306
306
** kwargs )
307
307
elif method == "flat_crossproduct" :
308
- jobs = flat_crossproduct_scatter (step , inputobj ,
309
- scatter ,
310
- cast (Callable [[Any ], Any ],
308
+ jobs = cast (Generator ,
309
+ flat_crossproduct_scatter (step , inputobj ,
310
+ scatter ,
311
+ cast (Callable [[Any ], Any ],
311
312
# known bug in mypy
312
313
# https://github.com/python/mypy/issues/797
313
- callback ), 0 , ** kwargs )
314
+ callback ), 0 , ** kwargs ) )
314
315
else :
315
316
_logger .debug (u"[job %s] job input %s" , step .name , json .dumps (inputobj , indent = 4 ))
316
317
inputobj = postScatterEval (inputobj )
@@ -332,7 +333,7 @@ def run(self, **kwargs):
332
333
_logger .debug (u"[%s] workflow starting" , self .name )
333
334
334
335
def job (self , joborder , output_callback , ** kwargs ):
335
- # type: (Dict[Text, Any], Callable[[Any, Any], Any], **Any) -> Generator[WorkflowJob, None, None]
336
+ # type: (Dict[Text, Any], Callable[[Any, Any], Any], **Any) -> Generator
336
337
self .state = {}
337
338
self .processStatus = "success"
338
339
@@ -405,7 +406,7 @@ def __init__(self, toolpath_object, **kwargs):
405
406
# TODO: statically validate data links instead of doing it at runtime.
406
407
407
408
def job (self , joborder , output_callback , ** kwargs ):
408
- # type: (Dict[Text, Text], Callable[[Any, Any], Any], **Any) -> Generator[WorkflowJob, None, None]
409
+ # type: (Dict[Text, Text], Callable[[Any, Any], Any], **Any) -> Generator
409
410
builder = self ._init_job (joborder , ** kwargs )
410
411
wj = WorkflowJob (self , ** kwargs )
411
412
yield wj
@@ -577,9 +578,25 @@ def setTotal(self, total): # type: (int) -> None
577
578
if self .completed == self .total :
578
579
self .output_callback (self .dest , self .processStatus )
579
580
581
+ def parallel_steps (steps , rc , kwargs ): # type: (List[Generator], ReceiveScatterOutput, Dict[str, Any]) -> Generator
582
+ while rc .completed < rc .total :
583
+ made_progress = False
584
+ for step in steps :
585
+ if kwargs .get ("on_error" , "stop" ) == "stop" and rc .processStatus != "success" :
586
+ break
587
+ for j in step :
588
+ if kwargs .get ("on_error" , "stop" ) == "stop" and rc .processStatus != "success" :
589
+ break
590
+ if j :
591
+ made_progress = True
592
+ yield j
593
+ else :
594
+ break
595
+ if not made_progress and rc .completed < rc .total :
596
+ yield None
580
597
581
598
def dotproduct_scatter (process , joborder , scatter_keys , output_callback , ** kwargs ):
582
- # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator[WorkflowJob, None, None]
599
+ # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator
583
600
l = None
584
601
for s in scatter_keys :
585
602
if l is None :
@@ -593,21 +610,23 @@ def dotproduct_scatter(process, joborder, scatter_keys, output_callback, **kwarg
593
610
594
611
rc = ReceiveScatterOutput (output_callback , output )
595
612
613
+ steps = []
596
614
for n in range (0 , l ):
597
615
jo = copy .copy (joborder )
598
616
for s in scatter_keys :
599
617
jo [s ] = joborder [s ][n ]
600
618
601
619
jo = kwargs ["postScatterEval" ](jo )
602
620
603
- for j in process .job (jo , functools .partial (rc .receive_scatter_output , n ), ** kwargs ):
604
- yield j
621
+ steps .append (process .job (jo , functools .partial (rc .receive_scatter_output , n ), ** kwargs ))
605
622
606
623
rc .setTotal (l )
607
624
625
+ return parallel_steps (steps , rc , kwargs )
626
+
608
627
609
628
def nested_crossproduct_scatter (process , joborder , scatter_keys , output_callback , ** kwargs ):
610
- # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator[WorkflowJob, None, None]
629
+ # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator
611
630
scatter_key = scatter_keys [0 ]
612
631
l = len (joborder [scatter_key ])
613
632
output = {} # type: Dict[Text,List[Text]]
@@ -616,25 +635,24 @@ def nested_crossproduct_scatter(process, joborder, scatter_keys, output_callback
616
635
617
636
rc = ReceiveScatterOutput (output_callback , output )
618
637
638
+ steps = []
619
639
for n in range (0 , l ):
620
640
jo = copy .copy (joborder )
621
641
jo [scatter_key ] = joborder [scatter_key ][n ]
622
642
623
643
if len (scatter_keys ) == 1 :
624
644
jo = kwargs ["postScatterEval" ](jo )
625
- for j in process .job (jo , functools .partial (rc .receive_scatter_output , n ), ** kwargs ):
626
- yield j
645
+ steps .append (process .job (jo , functools .partial (rc .receive_scatter_output , n ), ** kwargs ))
627
646
else :
628
- for j in nested_crossproduct_scatter (process , jo ,
647
+ steps . append ( nested_crossproduct_scatter (process , jo ,
629
648
scatter_keys [1 :], cast ( # known bug with mypy
630
- # https://github.com/python/mypy/issues/797
649
+ # https://github.com/python/mypy/issues/797g
631
650
Callable [[Any ], Any ],
632
- functools .partial (rc .receive_scatter_output , n )),
633
- ** kwargs ):
634
- yield j
651
+ functools .partial (rc .receive_scatter_output , n )), ** kwargs ))
635
652
636
653
rc .setTotal (l )
637
654
655
+ return parallel_steps (steps , rc , kwargs )
638
656
639
657
def crossproduct_size (joborder , scatter_keys ):
640
658
# type: (Dict[Text, Any], List[Text]) -> int
@@ -650,7 +668,7 @@ def crossproduct_size(joborder, scatter_keys):
650
668
return sum
651
669
652
670
def flat_crossproduct_scatter (process , joborder , scatter_keys , output_callback , startindex , ** kwargs ):
653
- # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Union[ReceiveScatterOutput,Callable[..., Any]], int, **Any) -> Generator[WorkflowJob, None, None ]
671
+ # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Union[ReceiveScatterOutput,Callable[..., Any]], int, **Any) -> Union[List[Generator], Generator ]
654
672
scatter_key = scatter_keys [0 ]
655
673
l = len (joborder [scatter_key ])
656
674
rc = None # type: ReceiveScatterOutput
@@ -665,20 +683,23 @@ def flat_crossproduct_scatter(process, joborder, scatter_keys, output_callback,
665
683
else :
666
684
raise Exception ("Unhandled code path. Please report this." )
667
685
686
+ steps = []
668
687
put = startindex
669
688
for n in range (0 , l ):
670
689
jo = copy .copy (joborder )
671
690
jo [scatter_key ] = joborder [scatter_key ][n ]
672
691
673
692
if len (scatter_keys ) == 1 :
674
693
jo = kwargs ["postScatterEval" ](jo )
675
- for j in process .job (jo , functools .partial (rc .receive_scatter_output , put ), ** kwargs ):
676
- yield j
694
+ steps .append (process .job (jo , functools .partial (rc .receive_scatter_output , put ), ** kwargs ))
677
695
put += 1
678
696
else :
679
- for j in flat_crossproduct_scatter (process , jo , scatter_keys [1 :], rc , put , ** kwargs ):
680
- if j :
681
- put += 1
682
- yield j
697
+ add = flat_crossproduct_scatter (process , jo , scatter_keys [1 :], rc , put , ** kwargs )
698
+ put += len (cast (List [Generator ], add ))
699
+ steps .extend (add )
683
700
684
- rc .setTotal (put )
701
+ if startindex == 0 and not isinstance (output_callback , ReceiveScatterOutput ):
702
+ rc .setTotal (put )
703
+ return parallel_steps (steps , rc , kwargs )
704
+ else :
705
+ return steps
0 commit comments