1
- from unittest .mock import patch , ANY
1
+ from unittest .mock import patch , ANY , Mock
2
2
3
+ from durabletask .client import TaskHubGrpcClient
3
4
from durabletask .internal .shared import (DefaultClientInterceptorImpl ,
4
5
get_default_host_address ,
5
6
get_grpc_channel )
7
+ import pytest
6
8
7
9
HOST_ADDRESS = 'localhost:50051'
8
10
METADATA = [('key1' , 'value1' ), ('key2' , 'value2' )]
@@ -85,4 +87,61 @@ def test_grpc_channel_with_host_name_protocol_stripping():
85
87
86
88
prefix = ""
87
89
get_grpc_channel (prefix + host_name , METADATA , True )
88
- mock_secure_channel .assert_called_with (host_name , ANY )
90
+ mock_secure_channel .assert_called_with (host_name , ANY )
91
+
92
+
93
+ @pytest .mark .parametrize ("timeout" , [None , 0 , 5 ])
94
+ def test_wait_for_orchestration_start_timeout (timeout ):
95
+ instance_id = "test-instance"
96
+
97
+ from durabletask .internal .orchestrator_service_pb2 import GetInstanceResponse , \
98
+ OrchestrationState , ORCHESTRATION_STATUS_RUNNING
99
+
100
+ response = GetInstanceResponse ()
101
+ state = OrchestrationState ()
102
+ state .instanceId = instance_id
103
+ state .orchestrationStatus = ORCHESTRATION_STATUS_RUNNING
104
+ response .orchestrationState .CopyFrom (state )
105
+
106
+ c = TaskHubGrpcClient ()
107
+ c ._stub = Mock ()
108
+ c ._stub .WaitForInstanceStart .return_value = response
109
+
110
+ grpc_timeout = None if timeout is None else timeout
111
+ c .wait_for_orchestration_start (instance_id , timeout = grpc_timeout )
112
+
113
+ # Verify WaitForInstanceStart was called with timeout=None
114
+ c ._stub .WaitForInstanceStart .assert_called_once ()
115
+ _ , kwargs = c ._stub .WaitForInstanceStart .call_args
116
+ if timeout is None or timeout == 0 :
117
+ assert kwargs .get ('timeout' ) is None
118
+ else :
119
+ assert kwargs .get ('timeout' ) == timeout
120
+
121
+ @pytest .mark .parametrize ("timeout" , [None , 0 , 5 ])
122
+ def test_wait_for_orchestration_completion_timeout (timeout ):
123
+ instance_id = "test-instance"
124
+
125
+ from durabletask .internal .orchestrator_service_pb2 import GetInstanceResponse , \
126
+ OrchestrationState , ORCHESTRATION_STATUS_COMPLETED
127
+
128
+ response = GetInstanceResponse ()
129
+ state = OrchestrationState ()
130
+ state .instanceId = instance_id
131
+ state .orchestrationStatus = ORCHESTRATION_STATUS_COMPLETED
132
+ response .orchestrationState .CopyFrom (state )
133
+
134
+ c = TaskHubGrpcClient ()
135
+ c ._stub = Mock ()
136
+ c ._stub .WaitForInstanceCompletion .return_value = response
137
+
138
+ grpc_timeout = None if timeout is None else timeout
139
+ c .wait_for_orchestration_completion (instance_id , timeout = grpc_timeout )
140
+
141
+ # Verify WaitForInstanceStart was called with timeout=None
142
+ c ._stub .WaitForInstanceCompletion .assert_called_once ()
143
+ _ , kwargs = c ._stub .WaitForInstanceCompletion .call_args
144
+ if timeout is None or timeout == 0 :
145
+ assert kwargs .get ('timeout' ) is None
146
+ else :
147
+ assert kwargs .get ('timeout' ) == timeout
0 commit comments