@@ -476,7 +476,7 @@ def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditi
476
476
default_value = 0
477
477
flags = 0
478
478
479
- status , _ , graph , _ , _ = handle_return (driver .cuStreamGetCaptureInfo (self ._mnff .stream .handle ))
479
+ status , _ , graph , * _ , _ = handle_return (driver .cuStreamGetCaptureInfo (self ._mnff .stream .handle ))
480
480
if status != driver .CUstreamCaptureStatus .CU_STREAM_CAPTURE_STATUS_ACTIVE :
481
481
raise RuntimeError ("Cannot create a conditional handle when graph is not being built" )
482
482
@@ -486,20 +486,22 @@ def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditi
486
486
487
487
def _cond_with_params (self , node_params ) -> GraphBuilder :
488
488
# Get current capture info to ensure we're in a valid state
489
- status , _ , graph , dependencies , num_dependencies = handle_return (
489
+ status , _ , graph , * deps_info , num_dependencies = handle_return (
490
490
driver .cuStreamGetCaptureInfo (self ._mnff .stream .handle )
491
491
)
492
492
if status != driver .CUstreamCaptureStatus .CU_STREAM_CAPTURE_STATUS_ACTIVE :
493
493
raise RuntimeError ("Cannot add conditional node when not actively capturing" )
494
494
495
495
# Add the conditional node to the graph
496
- node = handle_return (driver .cuGraphAddNode (graph , dependencies , num_dependencies , node_params ))
496
+ deps_info_update = [
497
+ [handle_return (driver .cuGraphAddNode (graph , * deps_info , num_dependencies , node_params ))]
498
+ ] + [None ] * (len (deps_info ) - 1 )
497
499
498
500
# Update the stream's capture dependencies
499
501
handle_return (
500
502
driver .cuStreamUpdateCaptureDependencies (
501
503
self ._mnff .stream .handle ,
502
- [ node ] , # dependencies
504
+ * deps_info_update , # dependencies, edgeData
503
505
1 , # numDependencies
504
506
driver .CUstreamUpdateCaptureDependencies_flags .CU_STREAM_SET_CAPTURE_DEPENDENCIES ,
505
507
)
@@ -677,17 +679,23 @@ def add_child(self, child_graph: GraphBuilder):
677
679
raise ValueError ("Parent graph is not being built." )
678
680
679
681
stream_handle = self ._mnff .stream .handle
680
- _ , _ , graph_out , dependencies_out , num_dependencies_out = handle_return (
682
+ _ , _ , graph_out , * deps_info_out , num_dependencies_out = handle_return (
681
683
driver .cuStreamGetCaptureInfo (stream_handle )
682
684
)
683
685
684
- child_node = handle_return (
685
- driver .cuGraphAddChildGraphNode (graph_out , dependencies_out , num_dependencies_out , child_graph ._mnff .graph )
686
- )
686
+ deps_info_update = [
687
+ [
688
+ handle_return (
689
+ driver .cuGraphAddChildGraphNode (
690
+ graph_out , deps_info_out [0 ], num_dependencies_out , child_graph ._mnff .graph
691
+ )
692
+ )
693
+ ]
694
+ ] + [None ] * (len (deps_info_out ) - 1 )
687
695
handle_return (
688
696
driver .cuStreamUpdateCaptureDependencies (
689
697
stream_handle ,
690
- [ child_node ],
698
+ * deps_info_update , # dependencies, edgeData
691
699
1 ,
692
700
driver .CUstreamUpdateCaptureDependencies_flags .CU_STREAM_SET_CAPTURE_DEPENDENCIES ,
693
701
)
0 commit comments