Skip to content

Commit 2f63f18

Browse files
committed
finally sorted out nested-workflow input/output propagation
1 parent 9667474 commit 2f63f18

File tree

2 files changed

+164
-110
lines changed

2 files changed

+164
-110
lines changed

nipype2pydra/statements/workflow_build.py

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -177,24 +177,21 @@ def targets(self):
177177

178178
@property
179179
def wf_in(self):
180-
if self.source_name is None:
180+
try:
181+
self.workflow_converter.get_input_from_conn(self)
182+
except KeyError:
183+
return False
184+
else:
181185
return True
182-
for inpt in self.workflow_converter.inputs.values():
183-
if self.target_name == inpt.node_name and str(self.target_in) == inpt.field:
184-
return True
185-
return False
186186

187187
@property
188188
def wf_out(self):
189-
if self.target_name is None:
189+
try:
190+
self.workflow_converter.get_output_from_conn(self)
191+
except KeyError:
192+
return False
193+
else:
190194
return True
191-
for output in self.workflow_converter.outputs.values():
192-
if (
193-
self.source_name == output.node_name
194-
and str(self.source_out) == output.field
195-
):
196-
return True
197-
return False
198195

199196
@cached_property
200197
def conditional(self):
@@ -215,29 +212,11 @@ def workflow_variable(self):
215212

216213
@property
217214
def wf_in_name(self):
218-
if not self.wf_in:
219-
raise ValueError(
220-
f"Cannot get wf_in_name for {self} as it is not a workflow input"
221-
)
222-
if self.source_name is None:
223-
return (
224-
self.source_out
225-
if not isinstance(self.source_out, DynamicField)
226-
else self.source_out.varname
227-
)
228-
return self.workflow_converter.get_input(self.target_in, self.target_name).name
215+
return self.workflow_converter.get_input_from_conn(self).name
229216

230217
@property
231218
def wf_out_name(self):
232-
if not self.wf_out:
233-
raise ValueError(
234-
f"Cannot get wf_out_name for {self} as it is not a workflow output"
235-
)
236-
if self.target_name is None:
237-
return self.target_in
238-
return self.workflow_converter.get_output(
239-
self.source_out, self.source_name
240-
).name
219+
return self.workflow_converter.get_output_from_conn(self).name
241220

242221
def __str__(self):
243222
if not self.include:
@@ -274,7 +253,7 @@ def __str__(self):
274253
# to add an "identity" node to pass it through
275254
intf_name = f"{base_task_name}_identity"
276255
code_str += (
277-
f"{self.indent}@pydra.mark.task\n"
256+
f"\n{self.indent}@pydra.mark.task\n"
278257
f"{self.indent}def {intf_name}({self.wf_in_name}: ty.Any) -> ty.Any:\n"
279258
f"{self.indent} return {self.wf_in_name}\n\n"
280259
f"{self.indent}{self.workflow_variable}.add("
@@ -669,11 +648,29 @@ def add_input_connection(self, conn: ConnectionStatement):
669648
else:
670649
target_in = conn.target_in
671650
target_name = None
672-
if target_name == self.nested_workflow.input_node:
651+
# Check for replacements for the given target field
652+
replacements = [
653+
i
654+
for i in self.nested_workflow.inputs.values()
655+
if any(n == target_name and f == target_in for n, f in i.replaces)
656+
]
657+
if len(replacements) > 1:
658+
raise ValueError(
659+
f"Multiple inputs found for replacements of '{target_in}' "
660+
f"field in '{target_name}' node in '{self.name}' workflow: "
661+
+ ", ".join(str(m) for m in replacements)
662+
)
663+
elif len(replacements) == 1:
664+
nested_input = replacements[0]
673665
target_name = None
674-
nested_input = self.nested_workflow.get_input(
675-
target_in, node_name=target_name, create=True
676-
)
666+
else:
667+
# If no replacements, create an input for the nested workflow
668+
if target_name == self.nested_workflow.input_node:
669+
target_name = None
670+
nested_input = self.nested_workflow.make_input(
671+
target_in,
672+
node_name=target_name,
673+
)
677674
conn.target_in = nested_input.name
678675
super().add_input_connection(conn)
679676
if target_name:
@@ -716,11 +713,26 @@ def add_output_connection(self, conn: ConnectionStatement):
716713
else:
717714
source_out = conn.source_out
718715
source_name = None
719-
if source_name == self.nested_workflow.output_node:
716+
replacements = [
717+
o
718+
for o in self.nested_workflow.outputs.values()
719+
if any(n == source_name and f == source_out for n, f in o.replaces)
720+
]
721+
if len(replacements) > 1:
722+
raise KeyError(
723+
f"Multiple outputs found for replacements of '{source_out}' "
724+
f"field in '{source_name}' node in '{self.name}' workflow: "
725+
+ ", ".join(str(m) for m in replacements)
726+
)
727+
elif len(replacements) == 1:
728+
nested_output = replacements[0]
720729
source_name = None
721-
nested_output = self.nested_workflow.get_output(
722-
source_out, node_name=source_name, create=True
723-
)
730+
else:
731+
if source_name == self.nested_workflow.output_node:
732+
source_name = None
733+
nested_output = self.nested_workflow.make_output(
734+
source_out, node_name=source_name
735+
)
724736
conn.source_out = nested_output.name
725737
super().add_output_connection(conn)
726738
if source_name:
@@ -759,7 +771,7 @@ def __str__(self):
759771
parts = self.attribute.split(".")
760772
nested_node_name = parts[2]
761773
attribute_name = parts[3]
762-
target_in = nested_wf.get_input(attribute_name, nested_node_name).name
774+
target_in = nested_wf.make_input(attribute_name, nested_node_name).name
763775
attribute = ".".join(parts[:2] + [target_in] + parts[4:])
764776
workflow_variable = self.nodes[0].workflow_variable
765777
assert (n.workflow_variable == workflow_variable for n in self.nodes)
@@ -782,6 +794,10 @@ def matches(cls, stmt, node_names: ty.List[str]) -> bool:
782794
return False
783795
return bool(cls.match_re(node_names).match(stmt))
784796

797+
@property
798+
def conditional(self):
799+
return len(self.indent) != 4
800+
785801
@classmethod
786802
def parse(
787803
cls, statement: str, workflow_converter: "WorkflowConverter"

0 commit comments

Comments
 (0)