Skip to content

Commit 386a937

Browse files
mahdieh-dsttclose
andcommitted
Modified by Tom
Co-authored-by: Tom Close <[email protected]>
1 parent af35292 commit 386a937

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

example-specs/smriprep.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ splits:
44
first_node: ds_surfs
55
- func_name: segmentation
66
first_node: lta2itk_fwd
7+
ignore_tasks:
8+
- smriprep.interfaces.DerivativesDataSink
9+
- nipype.interfaces.utility.base.IdentityInterface
710
args:
811
debug: false
912
freesurfer: true

nipype2pydra/workflow.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ def __init__(self, spec):
1313
**self._parse_workflow_args(self.spec['args'])
1414
)
1515

16-
def node_connections(self):
16+
def node_connections(self, workflow):
1717
connections = defaultdict(dict)
1818

19-
for edge, props in self.wf._graph.edges.items():
19+
# Get connections from workflow graph
20+
for edge, props in workflow._graph.edges.items():
2021
src_node = edge[0].name
2122
dest_node = edge[1].name
2223
for node_conn in props['connect']:
@@ -29,11 +30,16 @@ def node_connections(self):
2930
else:
3031
src_field = src_field.split('.')[-1]
3132
connections[dest_node][dest_field] = f"{src_node}.lzout.{src_field}"
33+
34+
# Look for connections in nested workflows via recursion
35+
for node in workflow.find_name_of_appropriate_method(): # TODO: find the method that iterates through the nodes of the workflow (but not nested nodes)
36+
if isinstance(node, Workflow): # TODO: find a way to check whether the node is a standard node or a nested workflow
37+
connections.update(self.node_connections(node))
3238
return connections
3339

3440
def generate(self, format_with_black=False):
3541

36-
connections = self.node_connections()
42+
connections = self.node_connections(self.wf)
3743
out_text = ""
3844
for node_name in self.wf.list_node_names():
3945
node = self.wf.get_node(node_name)
@@ -51,15 +57,15 @@ def generate(self, format_with_black=False):
5157
pass
5258
if isinstance(val, str) and '\n' in val:
5359
val = '"""' + val + '""""'
54-
node_args += f",\n {arg}={val}"
60+
node_args += f",\n {arg}={val}"
5561

5662
for arg, val in connections[node.name].items():
57-
node_args += f",\n {arg}={val}"
63+
node_args += f",\n {arg}=wf.{val}"
5864

5965
out_text += f"""
60-
wf.add({task_type}(
61-
name="{node.name}"{node_args}
62-
)"""
66+
wf.add({task_type}(
67+
name="{node.name}"{node_args}
68+
)"""
6369

6470
if format_with_black:
6571
out_text = black.format_file_contents(out_text, fast=False, mode=black.FileMode())

0 commit comments

Comments
 (0)