@@ -24,6 +24,8 @@ class TaskConverter:
24
24
output_templates : dict = attrs .field (factory = dict )
25
25
output_callables : dict = attrs .field (factory = dict )
26
26
doctest : dict = attrs .field (factory = dict )
27
+ tests_inputs : list = attrs .field (factory = list )
28
+ tests_outputs : list = attrs .field (factory = list )
27
29
callables_module : ModuleType = attrs .field (
28
30
converter = import_module_from_path , default = None
29
31
)
@@ -39,11 +41,11 @@ def nipype_interface(self) -> nipype.interfaces.base.BaseInterface:
39
41
40
42
@property
41
43
def nipype_input_spec (self ) -> nipype .interfaces .base .BaseInterfaceInputSpec :
42
- return self .nipype_interface .input_spec
44
+ return self .nipype_interface .input_spec ()
43
45
44
46
@property
45
47
def nipype_output_spec (self ) -> nipype .interfaces .base .BaseTraitedSpec :
46
- return self .nipype_interface .input_spec
48
+ return self .nipype_interface .output_spec ()
47
49
48
50
def generate (self , output_file : Path ):
49
51
"""creating pydra input/output spec from nipype specs
@@ -52,18 +54,12 @@ def generate(self, output_file: Path):
52
54
input_fields , inp_templates = self .convert_input_fields ()
53
55
output_fields = self .convert_output_spec (fields_from_template = inp_templates )
54
56
55
- input_spec = specs .SpecInfo (
56
- name = "Input" , fields = input_fields , bases = (specs .ShellSpec ,)
57
- )
58
- output_spec = specs .SpecInfo (
59
- name = "Output" , fields = output_fields , bases = (specs .ShellOutSpec ,)
60
- )
61
-
62
57
testdir = output_file .parent / "tests"
58
+ testdir .mkdir ()
63
59
filename_test = testdir / f"test_spec_{ output_file .name } "
64
60
filename_test_run = testdir / f"test_run_{ output_file .name } "
65
61
66
- self .write_task (output_file , input_spec , output_spec )
62
+ self .write_task (output_file , input_fields , output_fields )
67
63
68
64
self .write_test (filename_test = filename_test )
69
65
self .write_test (filename_test = filename_test_run , run = True )
@@ -127,9 +123,7 @@ def pydra_fld_input(self, field, nm):
127
123
tp_pdr = str
128
124
elif getattr (field , "genfile" ):
129
125
if nm in self .output_templates :
130
- metadata_pdr ["output_file_template" ] = self .interface_spec [
131
- "output_templates"
132
- ][nm ]
126
+ metadata_pdr ["output_file_template" ] = self .output_templates [nm ]
133
127
if tp_pdr in [
134
128
specs .File ,
135
129
specs .Directory ,
@@ -306,16 +300,16 @@ def types_to_names(spec_fields):
306
300
spec_str += "import typing as ty\n "
307
301
spec_str += functions_str
308
302
spec_str += f"input_fields = { input_fields_str } \n "
309
- spec_str += f"{ self .interface_name } _input_spec = specs.SpecInfo(name='Input', fields=input_fields, bases=(specs.ShellSpec,))\n \n "
303
+ spec_str += f"{ self .task_name } _input_spec = specs.SpecInfo(name='Input', fields=input_fields, bases=(specs.ShellSpec,))\n \n "
310
304
spec_str += f"output_fields = { output_fields_str } \n "
311
- spec_str += f"{ self .interface_name } _output_spec = specs.SpecInfo(name='Output', fields=output_fields, bases=(specs.ShellOutSpec,))\n \n "
305
+ spec_str += f"{ self .task_name } _output_spec = specs.SpecInfo(name='Output', fields=output_fields, bases=(specs.ShellOutSpec,))\n \n "
312
306
313
- spec_str += f"class { self .interface_name } (ShellCommandTask):\n "
307
+ spec_str += f"class { self .task_name } (ShellCommandTask):\n "
314
308
if self .doctest :
315
309
spec_str += self .create_doctest ()
316
- spec_str += f" input_spec = { self .interface_name } _input_spec\n "
317
- spec_str += f" output_spec = { self .interface_name } _output_spec\n "
318
- spec_str += f" executable='{ self .cmd } '\n "
310
+ spec_str += f" input_spec = { self .task_name } _input_spec\n "
311
+ spec_str += f" output_spec = { self .task_name } _output_spec\n "
312
+ spec_str += f" executable='{ self .nipype_interface . _cmd } '\n "
319
313
320
314
for tp_repl in self .TYPE_REPLACE :
321
315
spec_str = spec_str .replace (* tp_repl )
@@ -352,18 +346,18 @@ def write_test(self, filename_test, run=False):
352
346
353
347
spec_str = "import os, pytest \n from pathlib import Path\n "
354
348
spec_str += (
355
- f"from ..{ self .interface_name .lower ()} import { self .interface_name } \n \n "
349
+ f"from ..{ self .task_name .lower ()} import { self .task_name } \n \n "
356
350
)
357
351
if run :
358
352
pass
359
353
spec_str += f"@pytest.mark.parametrize('inputs, outputs', { tests_inp_outp } )\n "
360
- spec_str += f"def test_{ self .interface_name } (test_data, inputs, outputs):\n "
354
+ spec_str += f"def test_{ self .task_name } (test_data, inputs, outputs):\n "
361
355
spec_str += " in_file = Path(test_data) / 'test.nii.gz'\n "
362
356
spec_str += " if inputs is None: inputs = {{}}\n "
363
357
spec_str += " for key, val in inputs.items():\n "
364
358
spec_str += " try: inputs[key] = eval(val)\n "
365
359
spec_str += " except: pass\n "
366
- spec_str += f" task = { self .interface_name } (in_file=in_file, **inputs)\n "
360
+ spec_str += f" task = { self .task_name } (in_file=in_file, **inputs)\n "
367
361
spec_str += (
368
362
" assert set(task.generated_output_names) == "
369
363
"set(['return_code', 'stdout', 'stderr'] + outputs)\n "
@@ -392,14 +386,14 @@ def write_test_error(self, input_error):
392
386
spec_str = "\n \n "
393
387
spec_str += f"@pytest.mark.parametrize('inputs, error', { input_error } )\n "
394
388
spec_str += (
395
- f"def test_{ self .interface_name } _exception(test_data, inputs, error):\n "
389
+ f"def test_{ self .task_name } _exception(test_data, inputs, error):\n "
396
390
)
397
391
spec_str += " in_file = Path(test_data) / 'test.nii.gz'\n "
398
392
spec_str += " if inputs is None: inputs = {{}}\n "
399
393
spec_str += " for key, val in inputs.items():\n "
400
394
spec_str += " try: inputs[key] = eval(val)\n "
401
395
spec_str += " except: pass\n "
402
- spec_str += f" task = { self .interface_name } (in_file=in_file, **inputs)\n "
396
+ spec_str += f" task = { self .task_name } (in_file=in_file, **inputs)\n "
403
397
spec_str += " with pytest.raises(eval(error)):\n "
404
398
spec_str += " task.generated_output_names\n "
405
399
@@ -409,7 +403,7 @@ def create_doctest(self):
409
403
"""adding doctests to the interfaces"""
410
404
cmdline = self .doctest .pop ("cmdline" )
411
405
doctest = ' """\n Example\n -------\n '
412
- doctest += f" >>> task = { self .interface_name } ()\n "
406
+ doctest += f" >>> task = { self .task_name } ()\n "
413
407
for key , val in self .doctest .items ():
414
408
if type (val ) is str :
415
409
doctest += f' >>> task.inputs.{ key } = "{ val } "\n '
0 commit comments