Skip to content

Commit 28659bc

Browse files
committed
fixed handling of super methods, added supported for aggregate_outputs and fix bug with inputs handling in _list_outputs
1 parent b550966 commit 28659bc

File tree

4 files changed

+90
-36
lines changed

4 files changed

+90
-36
lines changed

nipype2pydra/interface/base.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
cleanup_function_body,
3232
insert_args_in_signature,
3333
extract_args,
34+
strip_comments,
35+
find_super_method,
3436
)
3537
from ..statements import (
3638
ImportStatement,
@@ -1021,6 +1023,11 @@ def _misc_cleanups(self, body: str) -> str:
10211023
)
10221024
body = re.sub(r"\bruntime\.(stdout|stderr)", r"\1", body)
10231025
body = re.sub(r"\boutputs\.(\w+)", r"outputs['\1']", body)
1026+
body = re.sub(r"getattr\(inputs, ([^)]+)\)", r"inputs[\1]", body)
1027+
body = re.sub(
1028+
r"setattr\(outputs, ([^,]+), ([^)]+)\)", r"outputs[\1] = \2", body
1029+
)
1030+
body = body.replace("TraitError", "KeyError")
10241031
body = body.replace("os.getcwd()", "output_dir")
10251032
return body
10261033

@@ -1209,7 +1216,9 @@ def process_method(
12091216
if method.__name__ in self.method_args:
12101217
args += [
12111218
f"{a}=None"
1212-
for a in (list(self.method_args[method.__name__]) + additional_args)
1219+
for a in (
1220+
list(self.method_args[method.__name__]) + list(additional_args)
1221+
)
12131222
]
12141223
# Insert method args in signature if present
12151224
return_types, method_body = post.split(":", maxsplit=1)
@@ -1291,7 +1300,9 @@ def replace_supers(self, method_body, super_base=None):
12911300
def replace_super(match):
12921301
super_method, base = find_super_method(super_base, match.group(1))
12931302
try:
1294-
return self.SPECIAL_SUPER_MAPPINGS[super_method]
1303+
return self.SPECIAL_SUPER_MAPPINGS[super_method].format(
1304+
args=match.group(2)
1305+
)
12951306
except KeyError:
12961307
try:
12971308
return name_map[match.group(1)] + "(" + match.group(2) + ")"
@@ -1316,7 +1327,9 @@ def unwrap_nested_methods(
13161327
"""
13171328
# Add args to the function signature of method calls
13181329
method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL)
1319-
method_names = [m.__name__ for m in self.referenced_methods]
1330+
method_names = [m.__name__ for m in self.referenced_methods] + list(
1331+
self.INCLUDED_METHODS
1332+
)
13201333
method_body = strip_comments(method_body)
13211334
omitted_methods = {}
13221335
for method_name in set(
@@ -1329,8 +1342,11 @@ def unwrap_nested_methods(
13291342
new_body = splits[0]
13301343
for name, args in zip(splits[1::2], splits[2::2]):
13311344
if name in omitted_methods:
1332-
new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_methods[name]]
1333-
new_body += extract_args(args)[-1][1:]
1345+
args, post = extract_args(args)[1:]
1346+
new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_methods[name]].format(
1347+
args=", ".join(args)
1348+
)
1349+
new_body += post[1:] # drop the leading parenthesis
13341350
continue
13351351
# Assign additional return values (which were previously saved to member
13361352
# attributes) to new variables from the method call
@@ -1375,8 +1391,9 @@ def unwrap_nested_methods(
13751391
return cleanup_function_body(method_body)
13761392

13771393
SPECIAL_SUPER_MAPPINGS = {
1378-
CommandLine._list_outputs: "{}",
1394+
CommandLine._list_outputs: "{{}}",
13791395
CommandLine._format_arg: "argstr.format(**inputs)",
1396+
CommandLine._filename_from_source: "{args} + '_generated'",
13801397
BaseInterface._check_version_requirements: "[]",
13811398
}
13821399

@@ -1431,19 +1448,3 @@ def pytest_configure(config):
14311448
else:
14321449
CATCH_CLI_EXCEPTIONS = True
14331450
"""
1434-
1435-
1436-
def find_super_method(
1437-
super_base: type, method_name: str
1438-
) -> ty.Tuple[ty.Callable, type]:
1439-
for base in super_base.__mro__[1:]:
1440-
if method_name in base.__dict__: # Found the match
1441-
return getattr(base, method_name), base
1442-
raise RuntimeError(
1443-
f"Could not find super of '{method_name}' method in base classes of "
1444-
f"{super_base}"
1445-
)
1446-
1447-
1448-
def strip_comments(src: str) -> str:
1449-
return re.sub(r"\s*#.*", "", src)

nipype2pydra/interface/shell_command.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
UsedSymbols,
1414
split_source_into_statements,
1515
INBUILT_NIPYPE_TRAIT_NAMES,
16+
find_super_method,
1617
)
1718
from fileformats.core.mixin import WithClassifiers
1819
from fileformats.generic import File, Directory
@@ -243,6 +244,10 @@ def callable_output_fields(self):
243244
)
244245
]
245246

247+
@property
248+
def callable_output_field_names(self):
249+
return [f[0] for f in self.callable_output_fields]
250+
246251
@cached_property
247252
def _format_arg_body(self):
248253
if "_format_arg" not in self.nipype_interface.__dict__:
@@ -295,7 +300,7 @@ def format_arg_code(self):
295300
if {val_arg} is None:
296301
return ""
297302
{body}
298-
303+
return argstr.format(**inputs)
299304
300305
"""
301306
for field_name in self.formatted_input_field_names:
@@ -326,7 +331,7 @@ def parse_inputs_code(self) -> str:
326331
body = self.unwrap_nested_methods(body, inputs_as_dict=True)
327332
body = self.replace_supers(body)
328333

329-
code_str = "def _parse_inputs(inputs):\n parsed_inputs = {{}}"
334+
code_str = "def _parse_inputs(inputs):\n parsed_inputs = {}"
330335
if re.findall(r"\bargstrs\b", body):
331336
code_str += f"\n argstrs = {self._format_argstrs!r}"
332337
code_str += f"""
@@ -374,40 +379,70 @@ def callables_code(self):
374379

375380
if not self.callable_output_fields:
376381
return ""
377-
if hasattr(self.nipype_interface, "aggregate_outputs"):
382+
code_str = ""
383+
if (
384+
find_super_method(self.nipype_interface, "aggregate_outputs")[1]
385+
is not BaseInterface
386+
):
378387
func_name = "aggregate_outputs"
379388
body = _strip_doc_string(
380389
inspect.getsource(self.nipype_interface.aggregate_outputs).split(
381390
"\n", 1
382391
)[-1]
383392
)
393+
need_list_outputs = bool(re.findall(r"\b_list_outputs\b", body))
394+
body = self._process_inputs(body)
395+
body = self._misc_cleanups(body)
396+
397+
if not body:
398+
return ""
399+
body = self.unwrap_nested_methods(
400+
body, additional_args=CALLABLES_ARGS, inputs_as_dict=True
401+
)
402+
body = self.replace_supers(body)
403+
404+
code_str += f"""def aggregate_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):
405+
inputs = attrs.asdict(inputs){self.parse_inputs_call}
406+
needed_outputs = {self.callable_output_field_names!r}
407+
{body}
408+
409+
410+
"""
411+
inputs_as_dict_call = ""
412+
384413
else:
385414
func_name = "_list_outputs"
415+
inputs_as_dict_call = "\n inputs = attrs.asdict(inputs)"
416+
need_list_outputs = True
417+
418+
if need_list_outputs:
386419
body = _strip_doc_string(
387420
inspect.getsource(self.nipype_interface._list_outputs).split("\n", 1)[
388421
-1
389422
]
390423
)
391-
body = self._process_inputs(body)
392-
body = self._misc_cleanups(body)
424+
body = self._process_inputs(body)
425+
body = self._misc_cleanups(body)
393426

394-
if not body:
395-
return ""
396-
body = self.unwrap_nested_methods(
397-
body, additional_args=CALLABLES_ARGS, inputs_as_dict=True
398-
)
399-
body = self.replace_supers(body)
427+
if not body:
428+
return ""
429+
body = self.unwrap_nested_methods(
430+
body, additional_args=CALLABLES_ARGS, inputs_as_dict=True
431+
)
432+
body = self.replace_supers(body)
400433

401-
code_str = f"""def {func_name}(inputs=None, stdout=None, stderr=None, output_dir=None):{self.parse_inputs_call}
434+
code_str += f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):{inputs_as_dict_call}{self.parse_inputs_call}
402435
{body}
436+
437+
403438
"""
404439
# Create separate function for each output field in the "callables" section
405440
for output_field in self.callable_output_fields:
406441
output_name = output_field[0]
407442
code_str += (
408443
f"\n\n\ndef {output_name}_callable(output_dir, inputs, stdout, stderr):\n"
409444
f" outputs = {func_name}(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr)\n"
410-
' return outputs["' + output_name + '"]\n\n'
445+
' return outputs.get("' + output_name + '", attrs.NOTHING)\n\n'
411446
)
412447
return code_str
413448

@@ -432,7 +467,7 @@ def _process_inputs(self, body: str) -> str:
432467
def parse_inputs_call(self):
433468
if not self.parse_inputs_code:
434469
return ""
435-
return "\n _parse_inputs(inputs) if inputs else {}"
470+
return "\n parsed_inputs = _parse_inputs(inputs) if inputs else {}"
436471

437472

438473
def _strip_doc_string(body: str) -> str:

nipype2pydra/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
types_converter,
2121
unwrap_nested_type,
2222
get_return_line,
23+
find_super_method,
24+
strip_comments,
2325
INBUILT_NIPYPE_TRAIT_NAMES,
2426
)
2527
from .symbols import ( # noqa: F401

nipype2pydra/utils/misc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,3 +539,19 @@ def get_return_line(func: ty.Union[str, ty.Callable]) -> str:
539539
if not match:
540540
return None
541541
return match.group(1).strip()
542+
543+
544+
def find_super_method(
545+
super_base: type, method_name: str
546+
) -> ty.Tuple[ty.Callable, type]:
547+
for base in super_base.__mro__[1:]:
548+
if method_name in base.__dict__: # Found the match
549+
return getattr(base, method_name), base
550+
raise RuntimeError(
551+
f"Could not find super of '{method_name}' method in base classes of "
552+
f"{super_base}"
553+
)
554+
555+
556+
def strip_comments(src: str) -> str:
557+
return re.sub(r"\s+#.*", "", src)

0 commit comments

Comments
 (0)