@@ -381,6 +381,12 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
381
381
"""
382
382
Returns the name of the input field in the workflow for the given node and field
383
383
escaped by the prefix of the node if present"""
384
+ try :
385
+ return self .make_input (
386
+ field_name = conn .target_in , node_name = conn .target_name , input_node_only = None
387
+ )
388
+ except KeyError :
389
+ pass
384
390
if conn .source_name is None or conn .source_name == self .input_node :
385
391
return self .make_input (field_name = conn .source_out )
386
392
elif conn .target_name is None :
@@ -395,6 +401,14 @@ def get_output_from_conn(self, conn: ConnectionStatement) -> WorkflowOutput:
395
401
"""
396
402
Returns the name of the input field in the workflow for the given node and field
397
403
escaped by the prefix of the node if present"""
404
+ try :
405
+ return self .make_output (
406
+ field_name = conn .source_out ,
407
+ node_name = conn .source_name ,
408
+ output_node_only = None ,
409
+ )
410
+ except KeyError :
411
+ pass
398
412
if conn .target_name is None or conn .target_name == self .output_node :
399
413
return self .make_output (field_name = conn .target_in )
400
414
elif conn .source_name is None :
@@ -411,7 +425,7 @@ def make_input(
411
425
self ,
412
426
field_name : str ,
413
427
node_name : ty .Optional [str ] = None ,
414
- input_node_only : bool = False ,
428
+ input_node_only : ty . Optional [ bool ] = False ,
415
429
) -> WorkflowInput :
416
430
"""
417
431
Returns the name of the input field in the workflow for the given node and field
@@ -430,6 +444,8 @@ def make_input(
430
444
elif len (matching ) == 1 :
431
445
return matching [0 ]
432
446
else :
447
+ if input_node_only is None :
448
+ raise KeyError
433
449
if node_name is None or node_name == self .input_node :
434
450
inpt_name = field_name
435
451
elif input_node_only :
@@ -453,7 +469,7 @@ def make_output(
453
469
self ,
454
470
field_name : str ,
455
471
node_name : ty .Optional [str ] = None ,
456
- output_node_only : bool = False ,
472
+ output_node_only : ty . Optional [ bool ] = False ,
457
473
) -> WorkflowOutput :
458
474
"""
459
475
Returns the name of the input field in the workflow for the given node and field
@@ -473,6 +489,8 @@ def make_output(
473
489
elif len (matching ) == 1 :
474
490
return matching [0 ]
475
491
else :
492
+ if output_node_only is None :
493
+ raise KeyError
476
494
if node_name is None or node_name == self .output_node :
477
495
outpt_name = field_name
478
496
elif output_node_only :
@@ -1056,7 +1074,7 @@ def _parse_statements(self, func_body: str) -> ty.Tuple[
1056
1074
statements = split_source_into_statements (func_body )
1057
1075
1058
1076
parsed = []
1059
- outputs = []
1077
+ output_names = []
1060
1078
workflow_init = None
1061
1079
workflow_init_index = None
1062
1080
assignments = defaultdict (list )
@@ -1106,13 +1124,13 @@ def _parse_statements(self, func_body: str) -> ty.Tuple[
1106
1124
for conn_stmt in conn_stmts :
1107
1125
self ._unprocessed_connections .append (conn_stmt )
1108
1126
if conn_stmt .wf_out :
1127
+ output_name = self .get_output_from_conn (conn_stmt ).name
1128
+ conn_stmt .target_in = output_name
1109
1129
if conn_stmt .conditional :
1110
1130
parsed .append (conn_stmt )
1111
- else :
1112
- outpt = self .get_output_from_conn (conn_stmt )
1113
- if outpt not in outputs :
1114
- parsed .append (conn_stmt )
1115
- outputs .append (outpt )
1131
+ elif output_name not in output_names :
1132
+ parsed .append (conn_stmt )
1133
+ output_names .append (output_name )
1116
1134
elif not conn_stmt .lzouttable :
1117
1135
parsed .append (conn_stmt )
1118
1136
parsed_stmt = conn_stmts [- 1 ]
0 commit comments