 import attrs
 import yaml
 from fileformats.core import from_mime, FileSet, Field
+from fileformats.core.exceptions import FormatRecognitionError
 from .utils import (
     UsedSymbols,
     split_source_into_statements,
@@ -51,6 +52,15 @@ def convert_node_prefixes(
     return {n: v if v is not None else "" for n, v in nodes_it}


+def convert_type(tp: ty.Union[str, type]) -> type:
+    if not isinstance(tp, str):
+        return tp
+    try:
+        return from_mime(tp)
+    except FormatRecognitionError:
+        return eval(tp)
+
+
 @attrs.define
 class WorkflowInterfaceField:

@@ -73,7 +83,7 @@ class WorkflowInterfaceField:
     )
     type: type = attrs.field(
         default=ty.Any,
-        converter=lambda t: from_mime(t) if isinstance(t, str) else t,
+        converter=convert_type,
         metadata={
             "help": "The type of the input/output of the converted workflow",
         },
@@ -117,6 +127,8 @@ def type_repr_(t):
                 return t.primitive.__name__
             elif issubclass(t, FileSet):
                 return t.__name__
+            elif t.__module__ == "builtins":
+                return t.__name__
             else:
                 return f"{t.__module__}.{t.__name__}"

@@ -154,6 +166,18 @@ class WorkflowInput(WorkflowInterfaceField):
         },
     )

+    include: bool = attrs.field(
+        default=False,
+        eq=False,
+        hash=False,
+        metadata={
+            "help": (
+                "Whether the input is required for the workflow once the unused nodes "
+                "have been filtered out"
+            )
+        },
+    )
+
     def __hash__(self):
         return super().__hash__()

@@ -321,6 +345,9 @@ class WorkflowConverter:
     _unprocessed_connections: ty.List[ConnectionStatement] = attrs.field(
         factory=list, repr=False
     )
+    used_inputs: ty.Optional[ty.Set[WorkflowInput]] = attrs.field(
+        default=None, repr=False
+    )

     def __attrs_post_init__(self):
         if self.workflow_variable is None:
@@ -383,7 +410,9 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
         escaped by the prefix of the node if present"""
         try:
             return self.make_input(
-                field_name=conn.target_in, node_name=conn.target_name, input_node_only=None
+                field_name=conn.source_out,
+                node_name=conn.source_name,
+                input_node_only=None,
             )
         except KeyError:
             pass
@@ -394,7 +423,7 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
                 f"Could not find output corresponding to '{conn.source_out}' input"
             )
         return self.make_input(
-            field_name=conn.target_in, node_name=conn.target_name, input_node_only=True
+            field_name=conn.source_out, node_name=conn.source_name, input_node_only=True
         )

     def get_output_from_conn(self, conn: ConnectionStatement) -> WorkflowOutput:
@@ -437,7 +466,7 @@ def make_input(
             if i.node_name == node_name and i.field == field_name
         ]
         if len(matching) > 1:
-            raise KeyError(
+            raise RuntimeError(
                 f"Multiple inputs found for '{field_name}' field in "
                 f"'{node_name}' node in '{self.name}' workflow"
             )
@@ -481,7 +510,7 @@ def make_output(
             if o.node_name == node_name and o.field == field_name
         ]
         if len(matching) > 1:
-            raise KeyError(
+            raise RuntimeError(
                 f"Multiple outputs found for '{field_name}' field in "
                 f"'{node_name}' node in '{self.name}' workflow: "
                 + ", ".join(str(m) for m in matching)
@@ -569,74 +598,6 @@ def add_connection_from_output(self, out_conn: ConnectionStatement):
         """Add a connection to an input of the workflow, adding the input if not present"""
         self._add_output_conn(out_conn, "from")

-    # def _add_input_conn(self, conn: ConnectionStatement, direction: str = "in"):
-    #     """Add an incoming connection to an input of the workflow, adding the input
-    #     if not present"""
-    #     if direction == "in":
-    #         node_name = conn.target_name
-    #         field_name = str(conn.target_in)
-    #     else:
-    #         node_name = conn.source_name
-    #         field_name = str(conn.source_out)
-    #     try:
-    #         inpt = self._input_mapping[(node_name, field_name)]
-    #     except KeyError:
-    #         if node_name == self.input_node:
-    #             inpt = WorkflowInput(
-    #                 name=field_name,
-    #                 node_name=self.input_node,
-    #                 field=field_name,
-    #             )
-    #         elif direction == "in":
-    #             name = conn.source_out
-    #             if conn.source_name != conn.workflow_converter.input_node:
-    #                 name = f"{conn.source_name}_{name}"
-    #             inpt = WorkflowInput(
-    #                 name=name,
-    #                 node_name=self.input_node,
-    #                 field=field_name,
-    #             )
-    #         else:
-    #             raise KeyError(
-    #                 f"Could not find input corresponding to '{field_name}' field in "
-    #                 f"'{conn.target_name}' node in '{self.name}' workflow"
-    #             )
-    #         self._input_mapping[(node_name, field_name)] = inpt
-    #         self.inputs[field_name] = inpt
-
-    #     inpt.in_conns.append(conn)
-
-    # def _add_output_conn(self, conn: ConnectionStatement, direction="in"):
-    #     if direction == "from":
-    #         node_name = conn.source_name
-    #         field_name = str(conn.source_out)
-    #     else:
-    #         node_name = conn.target_name
-    #         field_name = str(conn.target_in)
-    #     try:
-    #         outpt = self._output_mapping[(node_name, field_name)]
-    #     except KeyError:
-    #         if node_name == self.output_node:
-    #             outpt = WorkflowOutput(
-    #                 name=field_name,
-    #                 node_name=self.output_node,
-    #                 field=field_name,
-    #             )
-    #         elif direction == "out":
-    #             outpt = WorkflowOutput(
-    #                 name=field_name,
-    #                 node_name=self.output_node,
-    #                 field=field_name,
-    #             )
-    #         else:
-    #             raise KeyError(
-    #                 f"Could not foutd output correspondoutg to '{field_name}' field out "
-    #                 f"'{conn.target_name}' node out '{self.name}' workflow"
-    #             )
-    #         self._output_mapping[(node_name, field_name)] = outpt
-    #         self.outputs[field_name] = outpt
-    #     outpt.out_conns.append(conn)
-
     @cached_property
     def used_symbols(self) -> UsedSymbols:
         return UsedSymbols.find(
@@ -651,13 +612,16 @@ def used_symbols(self) -> UsedSymbols:
             translations=self.package.all_import_translations,
         )

-    @cached_property
+    @property
     def used_configs(self) -> ty.List[str]:
         return self._converted_code[1]

-    @cached_property
+    @property
     def converted_code(self) -> ty.List[str]:
-        return self._converted_code[0]
+        try:
+            return self._converted_code[0]
+        except AttributeError as e:
+            raise RuntimeError("caught AttributeError") from e

     @cached_property
     def input_output_imports(self) -> ty.List[ImportStatement]:
@@ -667,10 +631,6 @@ def input_output_imports(self) -> ty.List[ImportStatement]:
             stmts.append(ImportStatement.from_object(tp))
         return ImportStatement.collate(stmts)

-    @cached_property
-    def inline_imports(self) -> ty.List[str]:
-        return [s for s in self.converted_code if isinstance(s, ImportStatement)]
-
     @cached_property
     def func_src(self):
         return inspect.getsource(self.nipype_function)
@@ -824,6 +784,10 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]:
             the names of the used configs
         """

+        for nested_workflow in self.nested_workflows.values():
+            # processing nested workflows first so we know which inputs are required
+            nested_workflow._converted_code
+
         declaration, func_args, post = extract_args(self.func_src)
         return_types = post[1:].split(":", 1)[0]  # Get the return type

@@ -846,14 +810,26 @@ def add_nonstd_types(tp):

         while conn_stack:
             conn = conn_stack.pop()
-            # Will only be included if connected from inputs to outputs, still coerces to
-            # false but
-            conn.include = 0
+            # Will only be included if connected from inputs to outputs. If included
+            # from input->output traversal nodes and conns are flagged as include=None,
+            # because this coerces to False but is differentiable from False when we
+            # come to do the traversal in the other direction
+            conn.include = None
             if conn.target_name:
                 sibling_target_nodes = self.nodes[conn.target_name]
+                exclude = True
                 for target_node in sibling_target_nodes:
-                    target_node.include = 0
-                    conn_stack.extend(target_node.out_conns)
+                    # Check to see if the input is required, so we can change its include
+                    # flag back to false if not
+                    if (
+                        not isinstance(target_node, AddNestedWorkflowStatement)
+                        or target_node.nested_workflow.inputs[conn.target_in].include
+                    ):
+                        target_node.include = None
+                        conn_stack.extend(target_node.out_conns)
+                        exclude = False
+                if exclude:
+                    conn.include = False

         # Walk through the graph backwards from the outputs and trim any unnecessary
         # connections
@@ -864,20 +840,26 @@ def add_nonstd_types(tp):

         nonstd_types.discard(ty.Any)

+        self.used_inputs = set()
+
         while conn_stack:
             conn = conn_stack.pop()
-            if (
-                conn.include == 0
-            ):  # if included forward from inputs and backwards from outputs
+            # if included forward from inputs and backwards from outputs
+            if conn.include is None:
                 conn.include = True
+            else:
+                continue
             if conn.source_name:
                 sibling_source_nodes = self.nodes[conn.source_name]
                 for source_node in sibling_source_nodes:
-                    if (
-                        source_node.include == 0
-                    ):  # if included forward from inputs and backwards from outputs
+                    # if included forward from inputs and backwards from outputs
+                    if source_node.include is None:
                         source_node.include = True
                         conn_stack.extend(source_node.in_conns)
+            else:
+                inpt = self.inputs[conn.source_out]
+                inpt.include = True
+                self.used_inputs.add(inpt)

         preamble = ""
         statements = copy(self.parsed_statements)
@@ -901,7 +883,7 @@ def add_nonstd_types(tp):
             self.package.find_and_replace_config_params(code_str, nested_configs)
         )

-        inputs_sig = [f"{i}=attrs.NOTHING" for i in self.inputs]
+        inputs_sig = [f"{i.name}=attrs.NOTHING" for i in self.used_inputs]

         # construct code string with modified signature
         signature = (