1
+ from __future__ import annotations
1
2
import json
2
3
import tempfile
3
4
from pathlib import Path
10
11
11
12
12
13
class WorkflowConverter :
13
- # creating the wf
14
+ # creating the wf
14
15
def __init__ (self , spec ):
15
16
self .spec = spec
16
- self .wf = load_class_or_func (self .spec ['function' ])(
17
- ** self ._parse_workflow_args (self .spec ['args' ])
18
- ) #loads the 'function' in smriprep.yaml, and implement the args (creates a dictionary)
19
17
20
- def node_connections (self , workflow ):
18
+ self .wf = load_class_or_func (self .spec ["function" ])(
19
+ ** self ._parse_workflow_args (self .spec ["args" ])
20
+ ) # loads the 'function' in smriprep.yaml, and implement the args (creates a dictionary)
21
+
22
+ def node_connections (self , workflow , functions : dict [str , dict ], wf_inputs : dict [str , str ], wf_outputs : dict [str , str ]):
21
23
connections = defaultdict (dict )
22
24
23
25
# iterates over wf graph, Get connections from workflow graph, store connections in a dictionary
24
26
for edge , props in workflow ._graph .edges .items ():
25
27
src_node = edge [0 ].name
26
28
dest_node = edge [1 ].name
27
- for node_conn in props ['connect' ]:
28
- src_field = node_conn [1 ]
29
- dest_field = node_conn [0 ]
30
- if src_field .startswith ('def' ):
31
- print (f"Not sure how to deal with { src_field } in { src_node } to "
32
- f"{ dest_node } .{ dest_field } " )
33
- continue
29
+ dest_node_fullname = workflow .get_node (dest_node ).fullname
30
+ for node_conn in props ["connect" ]:
31
+ src_field = node_conn [0 ]
32
+ dest_field = node_conn [1 ]
33
+ if src_field .startswith ("def" ):
34
+ functions [dest_node_fullname ][dest_field ] = src_field
34
35
else :
35
- src_field = src_field .split ('.' )[- 1 ]
36
- connections [dest_node ][dest_field ] = f"{ src_node } .lzout.{ src_field } "
36
+ connections [dest_node_fullname ][
37
+ dest_field
38
+ ] = f"{ src_node } .lzout.{ src_field } "
37
39
38
- # Look for connections in nested workflows via recursion
39
- for node in workflow ._get_all_nodes (): #TODO: find the method that iterates through the nodes of the workflow (but not nested nodes)---placeholder: find_name_of_appropriate_method
40
- if isinstance (node , Workflow ): # TODO: find a way to check whether the node is a standard node or a nested workflow
41
- connections .update (self .node_connections (node ))
40
+ for nested_wf in workflow ._nested_workflows_cache :
41
+ connections .update (self .node_connections (nested_wf , functions = functions ))
42
42
return connections
43
43
44
+
44
45
def generate (self , format_with_black = False ):
45
46
46
- connections = self .node_connections (self .wf )
47
+ functions = defaultdict (dict )
48
+ connections = self .node_connections (self .wf , functions = functions )
47
49
out_text = ""
48
50
for node_name in self .wf .list_node_names ():
49
51
node = self .wf .get_node (node_name )
@@ -59,11 +61,11 @@ def generate(self, format_with_black=False):
59
61
val = json .dumps (val )
60
62
except TypeError :
61
63
pass
62
- if isinstance (val , str ) and ' \n ' in val :
64
+ if isinstance (val , str ) and " \n " in val :
63
65
val = '"""' + val + '""""'
64
66
node_args += f",\n { arg } ={ val } "
65
67
66
- for arg , val in connections [node .name ].items ():
68
+ for arg , val in connections [node .fullname ].items ():
67
69
node_args += f",\n { arg } =wf.{ val } "
68
70
69
71
out_text += f"""
@@ -72,16 +74,18 @@ def generate(self, format_with_black=False):
72
74
)"""
73
75
74
76
if format_with_black :
75
- out_text = black .format_file_contents (out_text , fast = False , mode = black .FileMode ())
77
+ out_text = black .format_file_contents (
78
+ out_text , fast = False , mode = black .FileMode ()
79
+ )
76
80
return out_text
77
81
78
82
@classmethod
79
83
def _parse_workflow_args (cls , args ):
80
84
dct = {}
81
85
for name , val in args .items ():
82
- if isinstance (val , dict ) and sorted (val .keys ()) == [' args' , ' type' ]:
83
- val = load_class_or_func (val [' type' ])(
84
- ** cls ._parse_workflow_args (val [' args' ])
86
+ if isinstance (val , dict ) and sorted (val .keys ()) == [" args" , " type" ]:
87
+ val = load_class_or_func (val [" type" ])(
88
+ ** cls ._parse_workflow_args (val [" args" ])
85
89
)
86
90
dct [name ] = val
87
91
return dct
@@ -92,6 +96,7 @@ def save_graph(self, out_path: Path, format: str = "svg", work_dir: Path = None)
92
96
work_dir = Path (work_dir )
93
97
graph_dot_path = work_dir / "wf-graph.dot"
94
98
self .wf .write_hierarchical_dotfile (graph_dot_path )
95
- dot_path = sp .check_output ("which dot" , shell = True ).decode ('utf-8' ).strip ()
96
- sp .check_call (f"{ dot_path } -T{ format } { str (graph_dot_path )} > { str (out_path )} " ,
97
- shell = True )
99
+ dot_path = sp .check_output ("which dot" , shell = True ).decode ("utf-8" ).strip ()
100
+ sp .check_call (
101
+ f"{ dot_path } -T{ format } { str (graph_dot_path )} > { str (out_path )} " , shell = True
102
+ )
0 commit comments