Skip to content

Commit 749a805

Browse files
committed
implemented trimming of unused inputs
1 parent c7b274b commit 749a805

File tree

2 files changed

+82
-95
lines changed

2 files changed

+82
-95
lines changed

nipype2pydra/statements/workflow_build.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,9 @@ def wf_out_name(self):
219219
return self.workflow_converter.get_output_from_conn(self).name
220220

221221
def __str__(self):
222-
if not self.include:
222+
if not self.include or (
223+
self.wf_in and not self.workflow_converter.inputs[self.source_out].include
224+
):
223225
return f"{self.indent}pass" if self.conditional else ""
224226
code_str = ""
225227
# Get source lazy-field
@@ -856,7 +858,7 @@ def __str__(self):
856858
+ ", ".join(
857859
f"'{i.name}': {i.type_repr}"
858860
for i in sorted(
859-
self.workflow_converter.inputs.values(), key=attrgetter("name")
861+
self.workflow_converter.used_inputs, key=attrgetter("name")
860862
)
861863
)
862864
+ "}, output_spec={"
@@ -867,7 +869,10 @@ def __str__(self):
867869
)
868870
)
869871
+ "}, "
870-
+ ", ".join(f"{i}={i}" for i in sorted(self.workflow_converter.inputs))
872+
+ ", ".join(
873+
f"{i}={i}"
874+
for i in sorted(j.name for j in self.workflow_converter.used_inputs)
875+
)
871876
+ ")\n\n"
872877
)
873878

nipype2pydra/workflow.py

Lines changed: 74 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import attrs
1414
import yaml
1515
from fileformats.core import from_mime, FileSet, Field
16+
from fileformats.core.exceptions import FormatRecognitionError
1617
from .utils import (
1718
UsedSymbols,
1819
split_source_into_statements,
@@ -51,6 +52,15 @@ def convert_node_prefixes(
5152
return {n: v if v is not None else "" for n, v in nodes_it}
5253

5354

55+
def convert_type(tp: ty.Union[str, type]) -> type:
56+
if not isinstance(tp, str):
57+
return tp
58+
try:
59+
return from_mime(tp)
60+
except FormatRecognitionError:
61+
return eval(tp)
62+
63+
5464
@attrs.define
5565
class WorkflowInterfaceField:
5666

@@ -73,7 +83,7 @@ class WorkflowInterfaceField:
7383
)
7484
type: type = attrs.field(
7585
default=ty.Any,
76-
converter=lambda t: from_mime(t) if isinstance(t, str) else t,
86+
converter=convert_type,
7787
metadata={
7888
"help": "The type of the input/output of the converted workflow",
7989
},
@@ -117,6 +127,8 @@ def type_repr_(t):
117127
return t.primitive.__name__
118128
elif issubclass(t, FileSet):
119129
return t.__name__
130+
elif t.__module__ == "builtins":
131+
return t.__name__
120132
else:
121133
return f"{t.__module__}.{t.__name__}"
122134

@@ -154,6 +166,18 @@ class WorkflowInput(WorkflowInterfaceField):
154166
},
155167
)
156168

169+
include: bool = attrs.field(
170+
default=False,
171+
eq=False,
172+
hash=False,
173+
metadata={
174+
"help": (
175+
"Whether the input is required for the workflow once the unused nodes "
176+
"have been filtered out"
177+
)
178+
},
179+
)
180+
157181
def __hash__(self):
158182
return super().__hash__()
159183

@@ -321,6 +345,9 @@ class WorkflowConverter:
321345
_unprocessed_connections: ty.List[ConnectionStatement] = attrs.field(
322346
factory=list, repr=False
323347
)
348+
used_inputs: ty.Optional[ty.Set[WorkflowInput]] = attrs.field(
349+
default=None, repr=False
350+
)
324351

325352
def __attrs_post_init__(self):
326353
if self.workflow_variable is None:
@@ -383,7 +410,9 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
383410
escaped by the prefix of the node if present"""
384411
try:
385412
return self.make_input(
386-
field_name=conn.target_in, node_name=conn.target_name, input_node_only=None
413+
field_name=conn.source_out,
414+
node_name=conn.source_name,
415+
input_node_only=None,
387416
)
388417
except KeyError:
389418
pass
@@ -394,7 +423,7 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
394423
f"Could not find output corresponding to '{conn.source_out}' input"
395424
)
396425
return self.make_input(
397-
field_name=conn.target_in, node_name=conn.target_name, input_node_only=True
426+
field_name=conn.source_out, node_name=conn.source_name, input_node_only=True
398427
)
399428

400429
def get_output_from_conn(self, conn: ConnectionStatement) -> WorkflowOutput:
@@ -437,7 +466,7 @@ def make_input(
437466
if i.node_name == node_name and i.field == field_name
438467
]
439468
if len(matching) > 1:
440-
raise KeyError(
469+
raise RuntimeError(
441470
f"Multiple inputs found for '{field_name}' field in "
442471
f"'{node_name}' node in '{self.name}' workflow"
443472
)
@@ -481,7 +510,7 @@ def make_output(
481510
if o.node_name == node_name and o.field == field_name
482511
]
483512
if len(matching) > 1:
484-
raise KeyError(
513+
raise RuntimeError(
485514
f"Multiple outputs found for '{field_name}' field in "
486515
f"'{node_name}' node in '{self.name}' workflow: "
487516
+ ", ".join(str(m) for m in matching)
@@ -569,74 +598,6 @@ def add_connection_from_output(self, out_conn: ConnectionStatement):
569598
"""Add a connection to an input of the workflow, adding the input if not present"""
570599
self._add_output_conn(out_conn, "from")
571600

572-
# def _add_input_conn(self, conn: ConnectionStatement, direction: str = "in"):
573-
# """Add an incoming connection to an input of the workflow, adding the input
574-
# if not present"""
575-
# if direction == "in":
576-
# node_name = conn.target_name
577-
# field_name = str(conn.target_in)
578-
# else:
579-
# node_name = conn.source_name
580-
# field_name = str(conn.source_out)
581-
# try:
582-
# inpt = self._input_mapping[(node_name, field_name)]
583-
# except KeyError:
584-
# if node_name == self.input_node:
585-
# inpt = WorkflowInput(
586-
# name=field_name,
587-
# node_name=self.input_node,
588-
# field=field_name,
589-
# )
590-
# elif direction == "in":
591-
# name = conn.source_out
592-
# if conn.source_name != conn.workflow_converter.input_node:
593-
# name = f"{conn.source_name}_{name}"
594-
# inpt = WorkflowInput(
595-
# name=name,
596-
# node_name=self.input_node,
597-
# field=field_name,
598-
# )
599-
# else:
600-
# raise KeyError(
601-
# f"Could not find input corresponding to '{field_name}' field in "
602-
# f"'{conn.target_name}' node in '{self.name}' workflow"
603-
# )
604-
# self._input_mapping[(node_name, field_name)] = inpt
605-
# self.inputs[field_name] = inpt
606-
607-
# inpt.in_conns.append(conn)
608-
609-
# def _add_output_conn(self, conn: ConnectionStatement, direction="in"):
610-
# if direction == "from":
611-
# node_name = conn.source_name
612-
# field_name = str(conn.source_out)
613-
# else:
614-
# node_name = conn.target_name
615-
# field_name = str(conn.target_in)
616-
# try:
617-
# outpt = self._output_mapping[(node_name, field_name)]
618-
# except KeyError:
619-
# if node_name == self.output_node:
620-
# outpt = WorkflowOutput(
621-
# name=field_name,
622-
# node_name=self.output_node,
623-
# field=field_name,
624-
# )
625-
# elif direction == "out":
626-
# outpt = WorkflowOutput(
627-
# name=field_name,
628-
# node_name=self.output_node,
629-
# field=field_name,
630-
# )
631-
# else:
632-
# raise KeyError(
633-
# f"Could not find output corresponding to '{field_name}' field in "
634-
# f"'{conn.target_name}' node in '{self.name}' workflow"
635-
# )
636-
# self._output_mapping[(node_name, field_name)] = outpt
637-
# self.outputs[field_name] = outpt
638-
# outpt.out_conns.append(conn)
639-
640601
@cached_property
641602
def used_symbols(self) -> UsedSymbols:
642603
return UsedSymbols.find(
@@ -651,13 +612,16 @@ def used_symbols(self) -> UsedSymbols:
651612
translations=self.package.all_import_translations,
652613
)
653614

654-
@cached_property
615+
@property
655616
def used_configs(self) -> ty.List[str]:
656617
return self._converted_code[1]
657618

658-
@cached_property
619+
@property
659620
def converted_code(self) -> ty.List[str]:
660-
return self._converted_code[0]
621+
try:
622+
return self._converted_code[0]
623+
except AttributeError as e:
624+
raise RuntimeError("caught AttributeError") from e
661625

662626
@cached_property
663627
def input_output_imports(self) -> ty.List[ImportStatement]:
@@ -667,10 +631,6 @@ def input_output_imports(self) -> ty.List[ImportStatement]:
667631
stmts.append(ImportStatement.from_object(tp))
668632
return ImportStatement.collate(stmts)
669633

670-
@cached_property
671-
def inline_imports(self) -> ty.List[str]:
672-
return [s for s in self.converted_code if isinstance(s, ImportStatement)]
673-
674634
@cached_property
675635
def func_src(self):
676636
return inspect.getsource(self.nipype_function)
@@ -824,6 +784,10 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]:
824784
the names of the used configs
825785
"""
826786

787+
for nested_workflow in self.nested_workflows.values():
788+
# processing nested workflows first so we know which inputs are required
789+
nested_workflow._converted_code
790+
827791
declaration, func_args, post = extract_args(self.func_src)
828792
return_types = post[1:].split(":", 1)[0] # Get the return type
829793

@@ -846,14 +810,26 @@ def add_nonstd_types(tp):
846810

847811
while conn_stack:
848812
conn = conn_stack.pop()
849-
# Will only be included if connected from inputs to outputs, still coerces to
850-
# false but
851-
conn.include = 0
813+
# Will only be included if connected from inputs to outputs. If included
814+
# from input->output traversal nodes and conns are flagged as include=None,
815+
# because this coerces to False but is differentiable from False when we
816+
# come to do the traversal in the other direction
817+
conn.include = None
852818
if conn.target_name:
853819
sibling_target_nodes = self.nodes[conn.target_name]
820+
exclude = True
854821
for target_node in sibling_target_nodes:
855-
target_node.include = 0
856-
conn_stack.extend(target_node.out_conns)
822+
# Check to see if the input is required, so we can change its include
823+
# flag back to false if not
824+
if (
825+
not isinstance(target_node, AddNestedWorkflowStatement)
826+
or target_node.nested_workflow.inputs[conn.target_in].include
827+
):
828+
target_node.include = None
829+
conn_stack.extend(target_node.out_conns)
830+
exclude = False
831+
if exclude:
832+
conn.include = False
857833

858834
# Walk through the graph backwards from the outputs and trim any unnecessary
859835
# connections
@@ -864,20 +840,26 @@ def add_nonstd_types(tp):
864840

865841
nonstd_types.discard(ty.Any)
866842

843+
self.used_inputs = set()
844+
867845
while conn_stack:
868846
conn = conn_stack.pop()
869-
if (
870-
conn.include == 0
871-
): # if included forward from inputs and backwards from outputs
847+
# if included forward from inputs and backwards from outputs
848+
if conn.include is None:
872849
conn.include = True
850+
else:
851+
continue
873852
if conn.source_name:
874853
sibling_source_nodes = self.nodes[conn.source_name]
875854
for source_node in sibling_source_nodes:
876-
if (
877-
source_node.include == 0
878-
): # if included forward from inputs and backwards from outputs
855+
# if included forward from inputs and backwards from outputs
856+
if source_node.include is None:
879857
source_node.include = True
880858
conn_stack.extend(source_node.in_conns)
859+
else:
860+
inpt = self.inputs[conn.source_out]
861+
inpt.include = True
862+
self.used_inputs.add(inpt)
881863

882864
preamble = ""
883865
statements = copy(self.parsed_statements)
@@ -901,7 +883,7 @@ def add_nonstd_types(tp):
901883
self.package.find_and_replace_config_params(code_str, nested_configs)
902884
)
903885

904-
inputs_sig = [f"{i}=attrs.NOTHING" for i in self.inputs]
886+
inputs_sig = [f"{i.name}=attrs.NOTHING" for i in self.used_inputs]
905887

906888
# construct code string with modified signature
907889
signature = (

0 commit comments

Comments
 (0)