Skip to content

Commit 9edf88c

Browse files
committed
Add support to RestartInstance to backend, unskip tests
1 parent 00d6f7b commit 9edf88c

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

durabletask/testing/in_memory_backend.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,40 @@ def PurgeInstances(self, request: pb.PurgeInstancesRequest, context):
395395
isComplete=wrappers_pb2.BoolValue(value=True),
396396
)
397397

398+
def RestartInstance(self, request: pb.RestartInstanceRequest, context):
399+
"""Restarts a completed orchestration instance."""
400+
with self._lock:
401+
instance = self._instances.get(request.instanceId)
402+
if not instance:
403+
context.abort(
404+
grpc.StatusCode.NOT_FOUND,
405+
f"Orchestration instance '{request.instanceId}' not found")
406+
return pb.RestartInstanceResponse()
407+
408+
if not self._is_terminal_status(instance.status):
409+
context.abort(
410+
grpc.StatusCode.FAILED_PRECONDITION,
411+
f"Orchestration instance '{request.instanceId}' is not in a terminal state")
412+
return pb.RestartInstanceResponse()
413+
414+
name = instance.name
415+
original_input = instance.input
416+
417+
if request.restartWithNewInstanceId:
418+
new_instance_id = uuid.uuid4().hex
419+
else:
420+
new_instance_id = request.instanceId
421+
# Remove the old instance so we can recreate it
422+
del self._instances[request.instanceId]
423+
self._orchestration_queue_set.discard(request.instanceId)
424+
self._state_waiters.pop(request.instanceId, None)
425+
426+
self._create_instance_internal(new_instance_id, name, original_input)
427+
428+
self._logger.info(
429+
f"Restarted instance '{request.instanceId}' as '{new_instance_id}'")
430+
return pb.RestartInstanceResponse(instanceId=new_instance_id)
431+
398432
def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
399433
"""Streams work items to the worker (orchestration and activity work items)."""
400434
self._logger.info("Worker connected and requesting work items")

tests/durabletask/test_orchestration_e2e.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,6 @@ def child(ctx: task.OrchestrationContext, _):
343343
assert state is None
344344

345345

346-
@pytest.mark.skip(reason="durabletask-go does not yet support RestartInstance")
347346
def test_restart_with_same_instance_id():
348347
def orchestrator(ctx: task.OrchestrationContext, _):
349348
result = yield ctx.call_activity(say_hello, input="World")
@@ -353,12 +352,12 @@ def say_hello(ctx: task.ActivityContext, input: str):
353352
return f"Hello, {input}!"
354353

355354
# Start a worker, which will connect to the sidecar in a background thread
356-
with worker.TaskHubGrpcWorker() as w:
355+
with worker.TaskHubGrpcWorker(host_address=HOST) as w:
357356
w.add_orchestrator(orchestrator)
358357
w.add_activity(say_hello)
359358
w.start()
360359

361-
task_hub_client = client.TaskHubGrpcClient()
360+
task_hub_client = client.TaskHubGrpcClient(host_address=HOST)
362361
id = task_hub_client.schedule_new_orchestration(orchestrator)
363362
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
364363
assert state is not None
@@ -375,7 +374,6 @@ def say_hello(ctx: task.ActivityContext, input: str):
375374
assert state.serialized_output == json.dumps("Hello, World!")
376375

377376

378-
@pytest.mark.skip(reason="durabletask-go does not yet support RestartInstance")
379377
def test_restart_with_new_instance_id():
380378
def orchestrator(ctx: task.OrchestrationContext, _):
381379
result = yield ctx.call_activity(say_hello, input="World")
@@ -385,12 +383,12 @@ def say_hello(ctx: task.ActivityContext, input: str):
385383
return f"Hello, {input}!"
386384

387385
# Start a worker, which will connect to the sidecar in a background thread
388-
with worker.TaskHubGrpcWorker() as w:
386+
with worker.TaskHubGrpcWorker(host_address=HOST) as w:
389387
w.add_orchestrator(orchestrator)
390388
w.add_activity(say_hello)
391389
w.start()
392390

393-
task_hub_client = client.TaskHubGrpcClient()
391+
task_hub_client = client.TaskHubGrpcClient(host_address=HOST)
394392
id = task_hub_client.schedule_new_orchestration(orchestrator)
395393
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
396394
assert state is not None

0 commit comments

Comments
 (0)