12
12
import black .report
13
13
import attrs
14
14
import yaml
15
- from fileformats .core import from_mime , FileSet
15
+ from fileformats .core import from_mime , FileSet , Field
16
16
from .utils import (
17
17
UsedSymbols ,
18
18
split_source_into_statements ,
@@ -114,6 +114,8 @@ def type_repr_(t):
114
114
)
115
115
if t in (ty .Any , ty .Union , ty .List , ty .Tuple ):
116
116
return f"ty.{ t .__name__ } "
117
+ elif issubclass (t , Field ):
118
+ return t .primative .__name__
117
119
elif issubclass (t , FileSet ):
118
120
return t .__name__
119
121
else :
@@ -407,7 +409,7 @@ def exported_outputs(self):
407
409
return (o for o in self .outputs .values () if o .export )
408
410
409
411
def get_input (
410
- self , field_name : str , node_name : ty .Optional [str ] = None
412
+ self , field_name : str , node_name : ty .Optional [str ] = None , create : bool = False
411
413
) -> WorkflowInput :
412
414
"""
413
415
Returns the name of the input field in the workflow for the given node and field
@@ -416,17 +418,21 @@ def get_input(
416
418
try :
417
419
return self ._input_mapping [(node_name , field_name )]
418
420
except KeyError :
419
- inpt_name = (
420
- field_name
421
- if node_name is None or node_name == self .input_node
422
- else f"{ node_name } _{ field_name } "
423
- )
421
+ if node_name is None or node_name == self .input_node :
422
+ inpt_name = field_name
423
+ elif create :
424
+ inpt_name = f"{ node_name } _{ field_name } "
425
+ else :
426
+ raise KeyError (
427
+ f"Unrecognised output corresponding to { node_name } :{ field_name } field, "
428
+ "set create=True to auto-create"
429
+ )
424
430
inpt = WorkflowInput (name = inpt_name , field = field_name , node_name = node_name )
425
431
self .inputs [inpt_name ] = self ._input_mapping [(node_name , field_name )] = inpt
426
432
return inpt
427
433
428
434
def get_output (
429
- self , field_name : str , node_name : ty .Optional [str ] = None
435
+ self , field_name : str , node_name : ty .Optional [str ] = None , create : bool = False
430
436
) -> WorkflowOutput :
431
437
"""
432
438
Returns the name of the input field in the workflow for the given node and field
@@ -435,11 +441,15 @@ def get_output(
435
441
try :
436
442
return self ._output_mapping [(node_name , field_name )]
437
443
except KeyError :
438
- outpt_name = (
439
- field_name
440
- if node_name is None or node_name == self .input_node
441
- else f"{ node_name } _{ field_name } "
442
- )
444
+ if node_name is None or node_name == self .output_node :
445
+ outpt_name = field_name
446
+ elif create :
447
+ outpt_name = f"{ node_name } _{ field_name } "
448
+ else :
449
+ raise KeyError (
450
+ f"Unrecognised output corresponding to { node_name } :{ field_name } field, "
451
+ "set create=True to auto-create"
452
+ )
443
453
outpt = WorkflowOutput (
444
454
name = outpt_name , field = field_name , node_name = node_name
445
455
)
@@ -923,11 +933,11 @@ def prepare_connections(self):
923
933
for node in nodes :
924
934
if isinstance (node , AddNestedWorkflowStatement ):
925
935
exported_inputs .update (
926
- (i .name , self .get_input (i .name , node_name ))
936
+ (i .name , self .get_input (i .name , node_name , create = True ))
927
937
for i in node .nested_workflow .exported_inputs
928
938
)
929
939
exported_outputs .update (
930
- (o .name , self .get_output (o .name , node_name ))
940
+ (o .name , self .get_output (o .name , node_name , create = True ))
931
941
for o in node .nested_workflow .exported_outputs
932
942
)
933
943
for inpt_name , exp_inpt in exported_inputs :
@@ -957,16 +967,20 @@ def prepare_connections(self):
957
967
self .parsed_statements .append (conn_stmt )
958
968
while self ._unprocessed_connections :
959
969
conn = self ._unprocessed_connections .pop ()
960
- if conn . wf_in :
961
- self .get_input (conn .source_out ). out_conns . append ( conn )
962
- else :
970
+ try :
971
+ inpt = self .get_input (conn .source_out , node_name = conn . source_name )
972
+ except KeyError :
963
973
for src_node in self .nodes [conn .source_name ]:
964
974
src_node .add_output_connection (conn )
965
- if conn .wf_out :
966
- self .get_output (conn .target_in ).in_conns .append (conn )
967
975
else :
976
+ inpt .out_conns .append (conn )
977
+ try :
978
+ outpt = self .get_output (conn .target_in , node_name = conn .target_name )
979
+ except KeyError :
968
980
for tgt_node in self .nodes [conn .target_name ]:
969
981
tgt_node .add_input_connection (conn )
982
+ else :
983
+ outpt .in_conns .append (conn )
970
984
971
985
def _parse_statements (self , func_body : str ) -> ty .Tuple [
972
986
ty .List [
0 commit comments