Skip to content

Commit b550966

Browse files
committed
handling special super methods
1 parent 68e09c7 commit b550966

File tree

2 files changed

+56
-30
lines changed

2 files changed

+56
-30
lines changed

nipype2pydra/interface/base.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import attrs
1515
from attrs.converters import default_if_none
1616
import nipype.interfaces.base
17-
from nipype.interfaces.base import traits_extension, CommandLine
17+
from nipype.interfaces.base import traits_extension, CommandLine, BaseInterface
1818
from pydra.engine import specs
1919
from pydra.engine.helpers import ensure_list
2020
from ..utils import (
@@ -1019,7 +1019,8 @@ def _misc_cleanups(self, body: str) -> str:
10191019
body,
10201020
flags=re.MULTILINE,
10211021
)
1022-
body = re.sub(r"\w+runtime\.(stdout|stderr)", r"\1", body)
1022+
body = re.sub(r"\bruntime\.(stdout|stderr)", r"\1", body)
1023+
body = re.sub(r"\boutputs\.(\w+)", r"outputs['\1']", body)
10231024
body = body.replace("os.getcwd()", "output_dir")
10241025
return body
10251026

@@ -1307,22 +1308,30 @@ def replace_super(match):
13071308

13081309
return re.sub(r"super\([^\)]*\)\.(\w+)\(([^\)]*)\)", replace_super, method_body)
13091310

1310-
def unwrap_nested_methods(self, method_body, additional_args=()):
1311+
def unwrap_nested_methods(
1312+
self, method_body, additional_args=(), inputs_as_dict: bool = False
1313+
):
13111314
"""
13121315
Converts nested method calls into function calls
13131316
"""
13141317
# Add args to the function signature of method calls
13151318
method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL)
13161319
method_names = [m.__name__ for m in self.referenced_methods]
1317-
unrecognised_methods = set(
1320+
method_body = strip_comments(method_body)
1321+
omitted_methods = {}
1322+
for method_name in set(
13181323
m for m in method_re.findall(method_body) if m not in method_names
1319-
)
1320-
assert (
1321-
not unrecognised_methods
1322-
), f"Found the following unrecognised methods {unrecognised_methods}"
1324+
):
1325+
omitted_methods[method_name] = find_super_method(
1326+
self.nipype_interface, method_name
1327+
)[0]
13231328
splits = method_re.split(method_body)
13241329
new_body = splits[0]
13251330
for name, args in zip(splits[1::2], splits[2::2]):
1331+
if name in omitted_methods:
1332+
new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_methods[name]]
1333+
new_body += extract_args(args)[-1][1:]
1334+
continue
13261335
# Assign additional return values (which were previously saved to member
13271336
# attributes) to new variables from the method call
13281337
if self.method_returns[name]:
@@ -1342,18 +1351,18 @@ def unwrap_nested_methods(self, method_body, additional_args=()):
13421351
else:
13431352
new_body += ",".join(self.method_returns[name]) + " = "
13441353
else:
1345-
raise NotImplementedError(
1354+
logger.warning(
13461355
"Could not augment the return value of the method converted from "
1347-
"a function with the previously assigned attributes as it is used "
1348-
"directly. Need to replace the method call with a variable and "
1349-
"assign the return value to it on a previous line"
1356+
f"a function '{name}' with the previously assigned attributes "
1357+
f"{self.method_returns[name]} as the method doesn't have a "
1358+
f"singular return statement at the end of the method"
13501359
)
13511360
# Insert additional arguments to the method call (which were previously
13521361
# accessed via member attributes)
13531362
new_body += name + insert_args_in_signature(
13541363
args,
13551364
[
1356-
f"{a}={a}"
1365+
f"{a}=inputs['{a}']" if inputs_as_dict else f"{a}={a}"
13571366
for a in (list(self.method_args[name]) + list(additional_args))
13581367
],
13591368
)
@@ -1368,6 +1377,7 @@ def unwrap_nested_methods(self, method_body, additional_args=()):
13681377
SPECIAL_SUPER_MAPPINGS = {
13691378
CommandLine._list_outputs: "{}",
13701379
CommandLine._format_arg: "argstr.format(**inputs)",
1380+
BaseInterface._check_version_requirements: "[]",
13711381
}
13721382

13731383
INPUT_KEYS = [
@@ -1433,3 +1443,7 @@ def find_super_method(
14331443
f"Could not find super of '{method_name}' method in base classes of "
14341444
f"{super_base}"
14351445
)
1446+
1447+
1448+
def strip_comments(src: str) -> str:
1449+
return re.sub(r"\s*#.*", "", src)

nipype2pydra/interface/shell_command.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -288,11 +288,10 @@ def format_arg_code(self):
288288
)
289289
if not body:
290290
return ""
291-
body = self.unwrap_nested_methods(body)
291+
body = self.unwrap_nested_methods(body, inputs_as_dict=True)
292292
body = self.replace_supers(body)
293293

294-
code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr):
295-
parsed_inputs = _parse_inputs(inputs) if inputs else {{}}
294+
code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr):{self.parse_inputs_call}
296295
if {val_arg} is None:
297296
return ""
298297
{body}
@@ -324,7 +323,7 @@ def parse_inputs_code(self) -> str:
324323

325324
# Strip out return value
326325
body = re.sub(r"\s*return .*\n", "", body)
327-
body = self.unwrap_nested_methods(body)
326+
body = self.unwrap_nested_methods(body, inputs_as_dict=True)
328327
body = self.replace_supers(body)
329328

330329
code_str = "def _parse_inputs(inputs):\n parsed_inputs = {{}}"
@@ -352,11 +351,10 @@ def defaults_code(self):
352351

353352
if not body:
354353
return ""
355-
body = self.unwrap_nested_methods(body)
354+
body = self.unwrap_nested_methods(body, inputs_as_dict=True)
356355
body = self.replace_supers(body)
357356

358-
code_str = f"""def _gen_filename(name, inputs):
359-
parsed_inputs = _parse_inputs(inputs) if inputs else {{}}
357+
code_str = f"""def _gen_filename(name, inputs):{self.parse_inputs_call}
360358
{body}
361359
"""
362360
# Create separate default function for each input field with genfile, which
@@ -376,31 +374,39 @@ def callables_code(self):
376374

377375
if not self.callable_output_fields:
378376
return ""
379-
380-
body = _strip_doc_string(
381-
inspect.getsource(self.nipype_interface._list_outputs).split("\n", 1)[-1]
382-
)
377+
if hasattr(self.nipype_interface, "aggregate_outputs"):
378+
func_name = "aggregate_outputs"
379+
body = _strip_doc_string(
380+
inspect.getsource(self.nipype_interface.aggregate_outputs).split(
381+
"\n", 1
382+
)[-1]
383+
)
384+
else:
385+
func_name = "_list_outputs"
386+
body = _strip_doc_string(
387+
inspect.getsource(self.nipype_interface._list_outputs).split("\n", 1)[
388+
-1
389+
]
390+
)
383391
body = self._process_inputs(body)
384392
body = self._misc_cleanups(body)
385393

386394
if not body:
387395
return ""
388396
body = self.unwrap_nested_methods(
389-
body,
390-
additional_args=CALLABLES_ARGS,
397+
body, additional_args=CALLABLES_ARGS, inputs_as_dict=True
391398
)
392399
body = self.replace_supers(body)
393400

394-
code_str = f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):
395-
parsed_inputs = _parse_inputs(inputs) if inputs else {{}}
401+
code_str = f"""def {func_name}(inputs=None, stdout=None, stderr=None, output_dir=None):{self.parse_inputs_call}
396402
{body}
397403
"""
398404
# Create separate function for each output field in the "callables" section
399405
for output_field in self.callable_output_fields:
400406
output_name = output_field[0]
401407
code_str += (
402408
f"\n\n\ndef {output_name}_callable(output_dir, inputs, stdout, stderr):\n"
403-
" outputs = _list_outputs(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr)\n"
409+
f" outputs = {func_name}(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr)\n"
404410
' return outputs["' + output_name + '"]\n\n'
405411
)
406412
return code_str
@@ -419,9 +425,15 @@ def _process_inputs(self, body: str) -> str:
419425
self.task_name,
420426
)
421427
body = input_re.sub(r"inputs['\1']", body)
422-
body = re.sub(r"self\.(?!inputs)(\w+)", r"parsed_inputs['\1']", body)
428+
body = re.sub(r"self\.(?!inputs)(\w+)\b(?!\()", r"parsed_inputs['\1']", body)
423429
return body
424430

431+
@property
432+
def parse_inputs_call(self):
433+
if not self.parse_inputs_code:
434+
return ""
435+
return "\n _parse_inputs(inputs) if inputs else {}"
436+
425437

426438
def _strip_doc_string(body: str) -> str:
427439
if re.match(r"\s*(\"|')", body):

0 commit comments

Comments
 (0)