@@ -71,6 +71,7 @@ class WorkflowInterfaceField:
71
71
},
72
72
)
73
73
node_name : ty .Optional [str ] = attrs .field (
74
+ default = None ,
74
75
metadata = {
75
76
"help" : "The name of the node that the input/output is connected to" ,
76
77
},
@@ -130,7 +131,7 @@ def type_repr_(t):
130
131
elif issubclass (t , Field ):
131
132
return t .primitive .__name__
132
133
elif issubclass (t , FileSet ):
133
- return t .__name__
134
+ return t .type_name
134
135
elif t .__module__ == "builtins" :
135
136
return t .__name__
136
137
else :
@@ -764,7 +765,9 @@ def write(
764
765
),
765
766
converted_code = self .test_code ,
766
767
used = self .test_used ,
767
- additional_imports = self .input_output_imports ,
768
+ additional_imports = (
769
+ self .input_output_imports + parse_imports ("import pytest" )
770
+ ),
768
771
)
769
772
770
773
conftest_fspath = test_module_fspath .parent / "conftest.py"
@@ -931,22 +934,51 @@ def parsed_statements(self):
931
934
def test_code (self ):
932
935
args_str = ", " .join (f"{ n } ={ v } " for n , v in self .test_inputs .items ())
933
936
934
- return f"""
937
+ code_str = f"""
935
938
936
- def test_{ self .name } ():
939
+
940
+ def test_{ self .name } _build():
937
941
workflow = { self .name } ({ args_str } )
938
942
assert isinstance(workflow, Workflow)
939
943
"""
940
944
945
+ inputs_dict = {}
946
+ for inpt in self .inputs .values ():
947
+ if issubclass (inpt .type , FileSet ):
948
+ inputs_dict [inpt .name ] = inpt .type .type_name + ".sample()"
949
+ elif inpt .name in self .test_inputs :
950
+ inputs_dict [inpt .name ] = self .test_inputs [inpt .name ]
951
+ args_str = ", " .join (f"{ n } ={ v } " for n , v in inputs_dict .items ())
952
+
953
+ code_str += f"""
954
+
955
+ @pytest.mark.skip(reason="Appropriate inputs for this workflow haven't been specified yet")
956
+ def test_{ self .name } _run():
957
+ workflow = { self .name } ({ args_str } )
958
+ result = workflow(plugin='serial')
959
+ print(result.out)
960
+ """
961
+ return code_str
962
+
941
963
@property
942
964
def test_used (self ):
965
+ nonstd_types = [
966
+ i .type for i in self .inputs .values () if issubclass (i .type , FileSet )
967
+ ]
968
+ nonstd_type_imports = []
969
+ for tp in itertools .chain (* (unwrap_nested_type (t ) for t in nonstd_types )):
970
+ nonstd_type_imports .append (ImportStatement .from_object (tp ))
971
+
943
972
return UsedSymbols (
944
973
module_name = self .nipype_module .__name__ ,
945
- imports = parse_imports (
946
- [
947
- f"from { self .output_module } import { self .name } " ,
948
- "from pydra.engine import Workflow" ,
949
- ]
974
+ imports = (
975
+ nonstd_type_imports
976
+ + parse_imports (
977
+ [
978
+ f"from { self .output_module } import { self .name } " ,
979
+ "from pydra.engine import Workflow" ,
980
+ ]
981
+ )
950
982
),
951
983
)
952
984
0 commit comments