@@ -177,24 +177,21 @@ def targets(self):
177
177
178
178
@property
179
179
def wf_in (self ):
180
- if self .source_name is None :
180
+ try :
181
+ self .workflow_converter .get_input_from_conn (self )
182
+ except KeyError :
183
+ return False
184
+ else :
181
185
return True
182
- for inpt in self .workflow_converter .inputs .values ():
183
- if self .target_name == inpt .node_name and str (self .target_in ) == inpt .field :
184
- return True
185
- return False
186
186
187
187
@property
188
188
def wf_out (self ):
189
- if self .target_name is None :
189
+ try :
190
+ self .workflow_converter .get_output_from_conn (self )
191
+ except KeyError :
192
+ return False
193
+ else :
190
194
return True
191
- for output in self .workflow_converter .outputs .values ():
192
- if (
193
- self .source_name == output .node_name
194
- and str (self .source_out ) == output .field
195
- ):
196
- return True
197
- return False
198
195
199
196
@cached_property
200
197
def conditional (self ):
@@ -215,29 +212,11 @@ def workflow_variable(self):
215
212
216
213
@property
217
214
def wf_in_name (self ):
218
- if not self .wf_in :
219
- raise ValueError (
220
- f"Cannot get wf_in_name for { self } as it is not a workflow input"
221
- )
222
- if self .source_name is None :
223
- return (
224
- self .source_out
225
- if not isinstance (self .source_out , DynamicField )
226
- else self .source_out .varname
227
- )
228
- return self .workflow_converter .get_input (self .target_in , self .target_name ).name
215
+ return self .workflow_converter .get_input_from_conn (self ).name
229
216
230
217
@property
231
218
def wf_out_name (self ):
232
- if not self .wf_out :
233
- raise ValueError (
234
- f"Cannot get wf_out_name for { self } as it is not a workflow output"
235
- )
236
- if self .target_name is None :
237
- return self .target_in
238
- return self .workflow_converter .get_output (
239
- self .source_out , self .source_name
240
- ).name
219
+ return self .workflow_converter .get_output_from_conn (self ).name
241
220
242
221
def __str__ (self ):
243
222
if not self .include :
@@ -274,7 +253,7 @@ def __str__(self):
274
253
# to add an "identity" node to pass it through
275
254
intf_name = f"{ base_task_name } _identity"
276
255
code_str += (
277
- f"{ self .indent } @pydra.mark.task\n "
256
+ f"\n { self .indent } @pydra.mark.task\n "
278
257
f"{ self .indent } def { intf_name } ({ self .wf_in_name } : ty.Any) -> ty.Any:\n "
279
258
f"{ self .indent } return { self .wf_in_name } \n \n "
280
259
f"{ self .indent } { self .workflow_variable } .add("
@@ -669,11 +648,29 @@ def add_input_connection(self, conn: ConnectionStatement):
669
648
else :
670
649
target_in = conn .target_in
671
650
target_name = None
672
- if target_name == self .nested_workflow .input_node :
651
+ # Check for replacements for the given target field
652
+ replacements = [
653
+ i
654
+ for i in self .nested_workflow .inputs .values ()
655
+ if any (n == target_name and f == target_in for n , f in i .replaces )
656
+ ]
657
+ if len (replacements ) > 1 :
658
+ raise ValueError (
659
+ f"Multiple inputs found for replacements of '{ target_in } ' "
660
+ f"field in '{ target_name } ' node in '{ self .name } ' workflow: "
661
+ + ", " .join (str (m ) for m in replacements )
662
+ )
663
+ elif len (replacements ) == 1 :
664
+ nested_input = replacements [0 ]
673
665
target_name = None
674
- nested_input = self .nested_workflow .get_input (
675
- target_in , node_name = target_name , create = True
676
- )
666
+ else :
667
+ # If no replacements, create an input for the nested workflow
668
+ if target_name == self .nested_workflow .input_node :
669
+ target_name = None
670
+ nested_input = self .nested_workflow .make_input (
671
+ target_in ,
672
+ node_name = target_name ,
673
+ )
677
674
conn .target_in = nested_input .name
678
675
super ().add_input_connection (conn )
679
676
if target_name :
@@ -716,11 +713,26 @@ def add_output_connection(self, conn: ConnectionStatement):
716
713
else :
717
714
source_out = conn .source_out
718
715
source_name = None
719
- if source_name == self .nested_workflow .output_node :
716
+ replacements = [
717
+ o
718
+ for o in self .nested_workflow .outputs .values ()
719
+ if any (n == source_name and f == source_out for n , f in o .replaces )
720
+ ]
721
+ if len (replacements ) > 1 :
722
+ raise KeyError (
723
+ f"Multiple outputs found for replacements of '{ source_out } ' "
724
+ f"field in '{ source_name } ' node in '{ self .name } ' workflow: "
725
+ + ", " .join (str (m ) for m in replacements )
726
+ )
727
+ elif len (replacements ) == 1 :
728
+ nested_output = replacements [0 ]
720
729
source_name = None
721
- nested_output = self .nested_workflow .get_output (
722
- source_out , node_name = source_name , create = True
723
- )
730
+ else :
731
+ if source_name == self .nested_workflow .output_node :
732
+ source_name = None
733
+ nested_output = self .nested_workflow .make_output (
734
+ source_out , node_name = source_name
735
+ )
724
736
conn .source_out = nested_output .name
725
737
super ().add_output_connection (conn )
726
738
if source_name :
@@ -759,7 +771,7 @@ def __str__(self):
759
771
parts = self .attribute .split ("." )
760
772
nested_node_name = parts [2 ]
761
773
attribute_name = parts [3 ]
762
- target_in = nested_wf .get_input (attribute_name , nested_node_name ).name
774
+ target_in = nested_wf .make_input (attribute_name , nested_node_name ).name
763
775
attribute = "." .join (parts [:2 ] + [target_in ] + parts [4 :])
764
776
workflow_variable = self .nodes [0 ].workflow_variable
765
777
assert (n .workflow_variable == workflow_variable for n in self .nodes )
@@ -782,6 +794,10 @@ def matches(cls, stmt, node_names: ty.List[str]) -> bool:
782
794
return False
783
795
return bool (cls .match_re (node_names ).match (stmt ))
784
796
797
+ @property
798
+ def conditional (self ):
799
+ return len (self .indent ) != 4
800
+
785
801
@classmethod
786
802
def parse (
787
803
cls , statement : str , workflow_converter : "WorkflowConverter"
0 commit comments