Skip to content

Commit 88de13a

Browse files
author
mahdieh-dst
committed
working on handling connections with nested workflows
1 parent 4c764ff commit 88de13a

File tree

2 files changed

+34
-29
lines changed

2 files changed

+34
-29
lines changed

nipype2pydra/workflow.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
import json
23
import tempfile
34
from pathlib import Path
@@ -10,40 +11,41 @@
1011

1112

1213
class WorkflowConverter:
13-
#creating the wf
14+
# creating the wf
1415
def __init__(self, spec):
1516
self.spec = spec
16-
self.wf = load_class_or_func(self.spec['function'])(
17-
**self._parse_workflow_args(self.spec['args'])
18-
) #loads the 'function' in smriprep.yaml, and implement the args (creates a dictionary)
1917

20-
def node_connections(self, workflow):
18+
self.wf = load_class_or_func(self.spec["function"])(
19+
**self._parse_workflow_args(self.spec["args"])
20+
) # loads the 'function' in smriprep.yaml, and implement the args (creates a dictionary)
21+
22+
def node_connections(self, workflow, functions: dict[str, dict], wf_inputs: dict[str, str], wf_outputs: dict[str, str]):
2123
connections = defaultdict(dict)
2224

2325
# iterates over wf graph, Get connections from workflow graph, store connections in a dictionary
2426
for edge, props in workflow._graph.edges.items():
2527
src_node = edge[0].name
2628
dest_node = edge[1].name
27-
for node_conn in props['connect']:
28-
src_field = node_conn[1]
29-
dest_field = node_conn[0]
30-
if src_field.startswith('def'):
31-
print(f"Not sure how to deal with {src_field} in {src_node} to "
32-
f"{dest_node}.{dest_field}")
33-
continue
29+
dest_node_fullname = workflow.get_node(dest_node).fullname
30+
for node_conn in props["connect"]:
31+
src_field = node_conn[0]
32+
dest_field = node_conn[1]
33+
if src_field.startswith("def"):
34+
functions[dest_node_fullname][dest_field] = src_field
3435
else:
35-
src_field = src_field.split('.')[-1]
36-
connections[dest_node][dest_field] = f"{src_node}.lzout.{src_field}"
36+
connections[dest_node_fullname][
37+
dest_field
38+
] = f"{src_node}.lzout.{src_field}"
3739

38-
# Look for connections in nested workflows via recursion
39-
for node in workflow._get_all_nodes(): #TODO: find the method that iterates through the nodes of the workflow (but not nested nodes)---placeholder: find_name_of_appropriate_method
40-
if isinstance(node, Workflow): # TODO: find a way to check whether the node is a standard node or a nested workflow
41-
connections.update(self.node_connections(node))
40+
for nested_wf in workflow._nested_workflows_cache:
41+
connections.update(self.node_connections(nested_wf, functions=functions))
4242
return connections
4343

44+
4445
def generate(self, format_with_black=False):
4546

46-
connections = self.node_connections(self.wf)
47+
functions = defaultdict(dict)
48+
connections = self.node_connections(self.wf, functions=functions)
4749
out_text = ""
4850
for node_name in self.wf.list_node_names():
4951
node = self.wf.get_node(node_name)
@@ -59,11 +61,11 @@ def generate(self, format_with_black=False):
5961
val = json.dumps(val)
6062
except TypeError:
6163
pass
62-
if isinstance(val, str) and '\n' in val:
64+
if isinstance(val, str) and "\n" in val:
6365
val = '"""' + val + '""""'
6466
node_args += f",\n {arg}={val}"
6567

66-
for arg, val in connections[node.name].items():
68+
for arg, val in connections[node.fullname].items():
6769
node_args += f",\n {arg}=wf.{val}"
6870

6971
out_text += f"""
@@ -72,16 +74,18 @@ def generate(self, format_with_black=False):
7274
)"""
7375

7476
if format_with_black:
75-
out_text = black.format_file_contents(out_text, fast=False, mode=black.FileMode())
77+
out_text = black.format_file_contents(
78+
out_text, fast=False, mode=black.FileMode()
79+
)
7680
return out_text
7781

7882
@classmethod
7983
def _parse_workflow_args(cls, args):
8084
dct = {}
8185
for name, val in args.items():
82-
if isinstance(val, dict) and sorted(val.keys()) == ['args', 'type']:
83-
val = load_class_or_func(val['type'])(
84-
**cls._parse_workflow_args(val['args'])
86+
if isinstance(val, dict) and sorted(val.keys()) == ["args", "type"]:
87+
val = load_class_or_func(val["type"])(
88+
**cls._parse_workflow_args(val["args"])
8589
)
8690
dct[name] = val
8791
return dct
@@ -92,6 +96,7 @@ def save_graph(self, out_path: Path, format: str = "svg", work_dir: Path = None)
9296
work_dir = Path(work_dir)
9397
graph_dot_path = work_dir / "wf-graph.dot"
9498
self.wf.write_hierarchical_dotfile(graph_dot_path)
95-
dot_path = sp.check_output("which dot", shell=True).decode('utf-8').strip()
96-
sp.check_call(f"{dot_path} -T{format} {str(graph_dot_path)} > {str(out_path)}",
97-
shell=True)
99+
dot_path = sp.check_output("which dot", shell=True).decode("utf-8").strip()
100+
sp.check_call(
101+
f"{dot_path} -T{format} {str(graph_dot_path)} > {str(out_path)}", shell=True
102+
)

tests/test_smriprep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_smriprep_conversion(pkg_dir, cli_runner):
1616
workflow,
1717
[
1818
f"{pkg_dir}/example-specs/smriprep.yaml",
19-
f"{pkg_dir}/outputs/smriprep_new.py"
19+
f"{pkg_dir}/outputs/smriprep_new1.py"
2020
]
2121
)
2222

0 commit comments

Comments
 (0)