Skip to content

Commit c7b274b

Browse files
committed
check for replaced connections in get_(input|output)_from_conn methods
1 parent 28a2828 commit c7b274b

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

nipype2pydra/workflow.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,12 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
381381
"""
382382
Returns the name of the input field in the workflow for the given node and field
383383
escaped by the prefix of the node if present"""
384+
try:
385+
return self.make_input(
386+
field_name=conn.target_in, node_name=conn.target_name, input_node_only=None
387+
)
388+
except KeyError:
389+
pass
384390
if conn.source_name is None or conn.source_name == self.input_node:
385391
return self.make_input(field_name=conn.source_out)
386392
elif conn.target_name is None:
@@ -395,6 +401,14 @@ def get_output_from_conn(self, conn: ConnectionStatement) -> WorkflowOutput:
395401
"""
396402
Returns the name of the input field in the workflow for the given node and field
397403
escaped by the prefix of the node if present"""
404+
try:
405+
return self.make_output(
406+
field_name=conn.source_out,
407+
node_name=conn.source_name,
408+
output_node_only=None,
409+
)
410+
except KeyError:
411+
pass
398412
if conn.target_name is None or conn.target_name == self.output_node:
399413
return self.make_output(field_name=conn.target_in)
400414
elif conn.source_name is None:
@@ -411,7 +425,7 @@ def make_input(
411425
self,
412426
field_name: str,
413427
node_name: ty.Optional[str] = None,
414-
input_node_only: bool = False,
428+
input_node_only: ty.Optional[bool] = False,
415429
) -> WorkflowInput:
416430
"""
417431
Returns the name of the input field in the workflow for the given node and field
@@ -430,6 +444,8 @@ def make_input(
430444
elif len(matching) == 1:
431445
return matching[0]
432446
else:
447+
if input_node_only is None:
448+
raise KeyError
433449
if node_name is None or node_name == self.input_node:
434450
inpt_name = field_name
435451
elif input_node_only:
@@ -453,7 +469,7 @@ def make_output(
453469
self,
454470
field_name: str,
455471
node_name: ty.Optional[str] = None,
456-
output_node_only: bool = False,
472+
output_node_only: ty.Optional[bool] = False,
457473
) -> WorkflowOutput:
458474
"""
459475
Returns the name of the input field in the workflow for the given node and field
@@ -473,6 +489,8 @@ def make_output(
473489
elif len(matching) == 1:
474490
return matching[0]
475491
else:
492+
if output_node_only is None:
493+
raise KeyError
476494
if node_name is None or node_name == self.output_node:
477495
outpt_name = field_name
478496
elif output_node_only:
@@ -1056,7 +1074,7 @@ def _parse_statements(self, func_body: str) -> ty.Tuple[
10561074
statements = split_source_into_statements(func_body)
10571075

10581076
parsed = []
1059-
outputs = []
1077+
output_names = []
10601078
workflow_init = None
10611079
workflow_init_index = None
10621080
assignments = defaultdict(list)
@@ -1106,13 +1124,13 @@ def _parse_statements(self, func_body: str) -> ty.Tuple[
11061124
for conn_stmt in conn_stmts:
11071125
self._unprocessed_connections.append(conn_stmt)
11081126
if conn_stmt.wf_out:
1127+
output_name = self.get_output_from_conn(conn_stmt).name
1128+
conn_stmt.target_in = output_name
11091129
if conn_stmt.conditional:
11101130
parsed.append(conn_stmt)
1111-
else:
1112-
outpt = self.get_output_from_conn(conn_stmt)
1113-
if outpt not in outputs:
1114-
parsed.append(conn_stmt)
1115-
outputs.append(outpt)
1131+
elif output_name not in output_names:
1132+
parsed.append(conn_stmt)
1133+
output_names.append(output_name)
11161134
elif not conn_stmt.lzouttable:
11171135
parsed.append(conn_stmt)
11181136
parsed_stmt = conn_stmts[-1]

0 commit comments

Comments
 (0)