Skip to content

Commit 9667474

Browse files
committed
debugging input/output mapping
1 parent 28d185b commit 9667474

File tree

3 files changed

+70
-42
lines changed

3 files changed

+70
-42
lines changed

nipype2pydra/statements/workflow_build.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -177,17 +177,24 @@ def targets(self):
177177

178178
@property
179179
def wf_in(self):
180-
return self.source_name is None or (
181-
(self.target_name, str(self.target_in))
182-
in self.workflow_converter._input_mapping
183-
)
180+
if self.source_name is None:
181+
return True
182+
for inpt in self.workflow_converter.inputs.values():
183+
if self.target_name == inpt.node_name and str(self.target_in) == inpt.field:
184+
return True
185+
return False
184186

185187
@property
186188
def wf_out(self):
187-
return self.target_name is None or (
188-
(self.source_name, str(self.source_out))
189-
in self.workflow_converter._output_mapping
190-
)
189+
if self.target_name is None:
190+
return True
191+
for output in self.workflow_converter.outputs.values():
192+
if (
193+
self.source_name == output.node_name
194+
and str(self.source_out) == output.field
195+
):
196+
return True
197+
return False
191198

192199
@cached_property
193200
def conditional(self):
@@ -212,24 +219,29 @@ def wf_in_name(self):
212219
raise ValueError(
213220
f"Cannot get wf_in_name for {self} as it is not a workflow input"
214221
)
215-
# source_out_name = (
216-
# self.source_out
217-
# if not isinstance(self.source_out, DynamicField)
218-
# else self.source_out.varname
219-
# )
220-
return self.workflow_converter.get_input(self.source_out, self.source_name).name
222+
if self.source_name is None:
223+
return (
224+
self.source_out
225+
if not isinstance(self.source_out, DynamicField)
226+
else self.source_out.varname
227+
)
228+
return self.workflow_converter.get_input(self.target_in, self.target_name).name
221229

222230
@property
223231
def wf_out_name(self):
224232
if not self.wf_out:
225233
raise ValueError(
226234
f"Cannot get wf_out_name for {self} as it is not a workflow output"
227235
)
228-
return self.workflow_converter.get_output(self.target_in, self.target_name).name
236+
if self.target_name is None:
237+
return self.target_in
238+
return self.workflow_converter.get_output(
239+
self.source_out, self.source_name
240+
).name
229241

230242
def __str__(self):
231243
if not self.include:
232-
return f"{self.indent}pass\n" if self.conditional else ""
244+
return f"{self.indent}pass" if self.conditional else ""
233245
code_str = ""
234246
# Get source lazy-field
235247
if self.wf_in:
@@ -450,7 +462,7 @@ def converted_interface(self):
450462

451463
def __str__(self):
452464
if not self.include:
453-
return f"{self.indent}pass\n" if self.conditional else ""
465+
return f"{self.indent}pass" if self.conditional else ""
454466
args = ["=".join(a) for a in self.arg_name_vals]
455467
conn_args = []
456468
for conn in sorted(self.in_conns, key=attrgetter("target_in")):
@@ -580,7 +592,7 @@ class AddNestedWorkflowStatement(AddNodeStatement):
580592

581593
def __str__(self):
582594
if not self.include:
583-
return f"{self.indent}pass\n" if self.conditional else ""
595+
return f"{self.indent}pass" if self.conditional else ""
584596
if self.nested_workflow:
585597
config_params = [
586598
f"{n}_{c}={n}_{c}" for n, c in self.nested_workflow.used_configs
@@ -659,7 +671,9 @@ def add_input_connection(self, conn: ConnectionStatement):
659671
target_name = None
660672
if target_name == self.nested_workflow.input_node:
661673
target_name = None
662-
nested_input = self.nested_workflow.get_input(target_in, node_name=target_name)
674+
nested_input = self.nested_workflow.get_input(
675+
target_in, node_name=target_name, create=True
676+
)
663677
conn.target_in = nested_input.name
664678
super().add_input_connection(conn)
665679
if target_name:
@@ -705,7 +719,7 @@ def add_output_connection(self, conn: ConnectionStatement):
705719
if source_name == self.nested_workflow.output_node:
706720
source_name = None
707721
nested_output = self.nested_workflow.get_output(
708-
source_out, node_name=source_name
722+
source_out, node_name=source_name, create=True
709723
)
710724
conn.source_out = nested_output.name
711725
super().add_output_connection(conn)
@@ -736,7 +750,7 @@ class NodeAssignmentStatement:
736750

737751
def __str__(self):
738752
if not any(n.include for n in self.nodes):
739-
return ""
753+
return f"{self.indent}pass" if self.conditional else ""
740754
node = self.nodes[0]
741755
node_name = node.name
742756
workflow_variable = self.nodes[0].workflow_variable

nipype2pydra/utils/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def from_named_dicts_converter(
463463
allow_none=False,
464464
) -> ty.Dict[str, T]:
465465
converted = {}
466-
for name, conv in dct.items() or []:
466+
for name, conv in (dct or {}).items():
467467
if isinstance(conv, dict):
468468
conv = klass(name=name, **conv)
469469
converted[name] = conv

nipype2pydra/workflow.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import black.report
1313
import attrs
1414
import yaml
15-
from fileformats.core import from_mime, FileSet
15+
from fileformats.core import from_mime, FileSet, Field
1616
from .utils import (
1717
UsedSymbols,
1818
split_source_into_statements,
@@ -114,6 +114,8 @@ def type_repr_(t):
114114
)
115115
if t in (ty.Any, ty.Union, ty.List, ty.Tuple):
116116
return f"ty.{t.__name__}"
117+
elif issubclass(t, Field):
118+
return t.primative.__name__
117119
elif issubclass(t, FileSet):
118120
return t.__name__
119121
else:
@@ -407,7 +409,7 @@ def exported_outputs(self):
407409
return (o for o in self.outputs.values() if o.export)
408410

409411
def get_input(
410-
self, field_name: str, node_name: ty.Optional[str] = None
412+
self, field_name: str, node_name: ty.Optional[str] = None, create: bool = False
411413
) -> WorkflowInput:
412414
"""
413415
Returns the name of the input field in the workflow for the given node and field
@@ -416,17 +418,21 @@ def get_input(
416418
try:
417419
return self._input_mapping[(node_name, field_name)]
418420
except KeyError:
419-
inpt_name = (
420-
field_name
421-
if node_name is None or node_name == self.input_node
422-
else f"{node_name}_{field_name}"
423-
)
421+
if node_name is None or node_name == self.input_node:
422+
inpt_name = field_name
423+
elif create:
424+
inpt_name = f"{node_name}_{field_name}"
425+
else:
426+
raise KeyError(
427+
f"Unrecognised output corresponding to {node_name}:{field_name} field, "
428+
"set create=True to auto-create"
429+
)
424430
inpt = WorkflowInput(name=inpt_name, field=field_name, node_name=node_name)
425431
self.inputs[inpt_name] = self._input_mapping[(node_name, field_name)] = inpt
426432
return inpt
427433

428434
def get_output(
429-
self, field_name: str, node_name: ty.Optional[str] = None
435+
self, field_name: str, node_name: ty.Optional[str] = None, create: bool = False
430436
) -> WorkflowOutput:
431437
"""
432438
Returns the name of the input field in the workflow for the given node and field
@@ -435,11 +441,15 @@ def get_output(
435441
try:
436442
return self._output_mapping[(node_name, field_name)]
437443
except KeyError:
438-
outpt_name = (
439-
field_name
440-
if node_name is None or node_name == self.input_node
441-
else f"{node_name}_{field_name}"
442-
)
444+
if node_name is None or node_name == self.output_node:
445+
outpt_name = field_name
446+
elif create:
447+
outpt_name = f"{node_name}_{field_name}"
448+
else:
449+
raise KeyError(
450+
f"Unrecognised output corresponding to {node_name}:{field_name} field, "
451+
"set create=True to auto-create"
452+
)
443453
outpt = WorkflowOutput(
444454
name=outpt_name, field=field_name, node_name=node_name
445455
)
@@ -923,11 +933,11 @@ def prepare_connections(self):
923933
for node in nodes:
924934
if isinstance(node, AddNestedWorkflowStatement):
925935
exported_inputs.update(
926-
(i.name, self.get_input(i.name, node_name))
936+
(i.name, self.get_input(i.name, node_name, create=True))
927937
for i in node.nested_workflow.exported_inputs
928938
)
929939
exported_outputs.update(
930-
(o.name, self.get_output(o.name, node_name))
940+
(o.name, self.get_output(o.name, node_name, create=True))
931941
for o in node.nested_workflow.exported_outputs
932942
)
933943
for inpt_name, exp_inpt in exported_inputs:
@@ -957,16 +967,20 @@ def prepare_connections(self):
957967
self.parsed_statements.append(conn_stmt)
958968
while self._unprocessed_connections:
959969
conn = self._unprocessed_connections.pop()
960-
if conn.wf_in:
961-
self.get_input(conn.source_out).out_conns.append(conn)
962-
else:
970+
try:
971+
inpt = self.get_input(conn.source_out, node_name=conn.source_name)
972+
except KeyError:
963973
for src_node in self.nodes[conn.source_name]:
964974
src_node.add_output_connection(conn)
965-
if conn.wf_out:
966-
self.get_output(conn.target_in).in_conns.append(conn)
967975
else:
976+
inpt.out_conns.append(conn)
977+
try:
978+
outpt = self.get_output(conn.target_in, node_name=conn.target_name)
979+
except KeyError:
968980
for tgt_node in self.nodes[conn.target_name]:
969981
tgt_node.add_input_connection(conn)
982+
else:
983+
outpt.in_conns.append(conn)
970984

971985
def _parse_statements(self, func_body: str) -> ty.Tuple[
972986
ty.List[

0 commit comments

Comments
 (0)