Skip to content

Commit fcb040d

Browse files
Copilotberndverst
andcommitted
Implement entity execution and worker support with comprehensive tests
Co-authored-by: berndverst <4535280+berndverst@users.noreply.github.com>
1 parent 6d240c0 commit fcb040d

File tree

2 files changed

+252
-0
lines changed

2 files changed

+252
-0
lines changed

durabletask/worker.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,12 @@ def __init__(
7575
class _Registry:
7676
orchestrators: dict[str, task.Orchestrator]
7777
activities: dict[str, task.Activity]
78+
entities: dict[str, task.Entity]
7879

7980
def __init__(self):
8081
self.orchestrators = {}
8182
self.activities = {}
83+
self.entities = {}
8284

8385
def add_orchestrator(self, fn: task.Orchestrator) -> str:
8486
if fn is None:
@@ -118,6 +120,25 @@ def add_named_activity(self, name: str, fn: task.Activity) -> None:
118120
def get_activity(self, name: str) -> Optional[task.Activity]:
119121
return self.activities.get(name)
120122

123+
def add_entity(self, fn: task.Entity) -> str:
124+
if fn is None:
125+
raise ValueError("An entity function argument is required.")
126+
127+
name = task.get_name(fn)
128+
self.add_named_entity(name, fn)
129+
return name
130+
131+
def add_named_entity(self, name: str, fn: task.Entity) -> None:
132+
if not name:
133+
raise ValueError("A non-empty entity name is required.")
134+
if name in self.entities:
135+
raise ValueError(f"A '{name}' entity already exists.")
136+
137+
self.entities[name] = fn
138+
139+
def get_entity(self, name: str) -> Optional[task.Entity]:
140+
return self.entities.get(name)
141+
121142

122143
class OrchestratorNotRegisteredError(ValueError):
123144
"""Raised when attempting to start an orchestration that is not registered"""
@@ -131,6 +152,12 @@ class ActivityNotRegisteredError(ValueError):
131152
pass
132153

133154

155+
class EntityNotRegisteredError(ValueError):
156+
"""Raised when attempting to call an entity that is not registered"""
157+
158+
pass
159+
160+
134161
class TaskHubGrpcWorker:
135162
"""A gRPC-based worker for processing durable task orchestrations and activities.
136163
@@ -279,6 +306,14 @@ def add_activity(self, fn: task.Activity) -> str:
279306
)
280307
return self._registry.add_activity(fn)
281308

309+
def add_entity(self, fn: task.Entity) -> str:
310+
"""Registers an entity function with the worker."""
311+
if self._is_running:
312+
raise RuntimeError(
313+
"Entities cannot be added while the worker is running."
314+
)
315+
return self._registry.add_entity(fn)
316+
282317
def start(self):
283318
"""Starts the worker on a background thread and begins listening for work items."""
284319
if self._is_running:
@@ -434,6 +469,13 @@ def stream_reader():
434469
stub,
435470
work_item.completionToken,
436471
)
472+
elif work_item.HasField("entityRequest"):
473+
self._async_worker_manager.submit_activity(
474+
self._execute_entity,
475+
work_item.entityRequest,
476+
stub,
477+
work_item.completionToken,
478+
)
437479
elif work_item.HasField("healthPing"):
438480
pass
439481
else:
@@ -569,6 +611,34 @@ def _execute_activity(
569611
f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
570612
)
571613

614+
def _execute_entity(
615+
self,
616+
req: pb.EntityBatchRequest,
617+
stub: stubs.TaskHubSidecarServiceStub,
618+
completionToken,
619+
):
620+
instance_id = req.instanceId
621+
try:
622+
executor = _EntityExecutor(self._registry, self._logger)
623+
result = executor.execute(req)
624+
result.completionToken = completionToken
625+
except Exception as ex:
626+
self._logger.exception(
627+
f"An error occurred while trying to execute entity '{instance_id}': {ex}"
628+
)
629+
failure_details = ph.new_failure_details(ex)
630+
result = pb.EntityBatchResult(
631+
failureDetails=failure_details,
632+
completionToken=completionToken,
633+
)
634+
635+
try:
636+
stub.CompleteEntityTask(result)
637+
except Exception as ex:
638+
self._logger.exception(
639+
f"Failed to deliver entity response for entity '{instance_id}' to sidecar: {ex}"
640+
)
641+
572642

573643
class _RuntimeOrchestrationContext(task.OrchestrationContext):
574644
_generator: Optional[Generator[task.Task, Any, Any]]
@@ -858,6 +928,36 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
858928

859929
self.set_continued_as_new(new_input, save_events)
860930

931+
def signal_entity(self, entity_id: str, operation_name: str, *,
932+
input: Optional[Any] = None) -> task.Task:
933+
# Create a signal entity action
934+
action = pb.OrchestratorAction()
935+
action.sendEntitySignal.CopyFrom(pb.SendSignalAction(
936+
instanceId=entity_id,
937+
name=operation_name,
938+
input=ph.get_string_value(shared.to_json(input)) if input is not None else None
939+
))
940+
941+
# Entity signals don't return values, so we create a completed task
942+
signal_task = task.CompletableTask()
943+
944+
# Store the action to be executed
945+
task_id = self._next_task_id()
946+
self._pending_actions[task_id] = action
947+
self._pending_tasks[task_id] = signal_task
948+
949+
# Mark as complete since signals don't have return values
950+
signal_task.complete(None)
951+
952+
return signal_task
953+
954+
def call_entity(self, entity_id: str, operation_name: str, *,
955+
input: Optional[Any] = None,
956+
retry_policy: Optional[task.RetryPolicy] = None) -> task.Task:
957+
# For now, entity calls are not directly supported in orchestrations
958+
# This would require additional protobuf support
959+
raise NotImplementedError("Direct entity calls from orchestrations are not yet supported. Use signal_entity instead.")
960+
861961

862962
class ExecutionResults:
863963
actions: list[pb.OrchestratorAction]
@@ -1260,6 +1360,81 @@ def execute(
12601360
return encoded_output
12611361

12621362

1363+
class _EntityExecutor:
1364+
def __init__(self, registry: _Registry, logger: logging.Logger):
1365+
self._registry = registry
1366+
self._logger = logger
1367+
1368+
def execute(self, req: pb.EntityBatchRequest) -> pb.EntityBatchResult:
1369+
"""Executes entity operations and returns the batch result."""
1370+
instance_id = req.instanceId
1371+
self._logger.debug(f"Executing entity batch for '{instance_id}' with {len(req.operations)} operation(s)...")
1372+
1373+
# Parse current entity state
1374+
current_state = shared.from_json(req.entityState.value) if not ph.is_empty(req.entityState) else None
1375+
1376+
# Extract entity type from instance ID (format: entitytype@key)
1377+
entity_type = "Unknown"
1378+
if "@" in instance_id:
1379+
entity_type = instance_id.split("@")[0]
1380+
1381+
results = []
1382+
actions = []
1383+
1384+
for operation in req.operations:
1385+
try:
1386+
# Get the entity function using the entity type from instanceId
1387+
fn = self._registry.get_entity(entity_type)
1388+
if not fn:
1389+
raise EntityNotRegisteredError(f"Entity function named '{entity_type}' was not registered!")
1390+
1391+
# Create entity context
1392+
ctx = task.EntityContext(
1393+
instance_id=instance_id,
1394+
operation_name=operation.operation,
1395+
is_new_entity=(current_state is None)
1396+
)
1397+
ctx.set_state(current_state)
1398+
1399+
# Parse operation input
1400+
operation_input = shared.from_json(operation.input.value) if not ph.is_empty(operation.input) else None
1401+
1402+
# Execute the entity operation
1403+
operation_output = fn(ctx, operation_input)
1404+
1405+
# Update state for next operation
1406+
current_state = ctx.get_state()
1407+
1408+
# Create operation result
1409+
result = pb.OperationResult()
1410+
if operation_output is not None:
1411+
result.success.CopyFrom(pb.OperationResultSuccess(
1412+
result=ph.get_string_value(shared.to_json(operation_output))
1413+
))
1414+
else:
1415+
result.success.CopyFrom(pb.OperationResultSuccess())
1416+
1417+
results.append(result)
1418+
1419+
except Exception as ex:
1420+
self._logger.exception(f"Error executing entity operation '{operation.operation}' on entity type '{entity_type}': {ex}")
1421+
1422+
# Create failure result
1423+
failure_details = ph.new_failure_details(ex)
1424+
result = pb.OperationResult()
1425+
result.failure.CopyFrom(pb.OperationResultFailure(
1426+
failureDetails=failure_details
1427+
))
1428+
results.append(result)
1429+
1430+
# Return batch result
1431+
return pb.EntityBatchResult(
1432+
results=results,
1433+
actions=actions,
1434+
entityState=ph.get_string_value(shared.to_json(current_state)) if current_state is not None else None
1435+
)
1436+
1437+
12631438
def _get_non_determinism_error(
12641439
task_id: int, action_name: str
12651440
) -> task.NonDeterminismError:

tests/durabletask/test_entities.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import unittest
55
from datetime import datetime
66
from durabletask import task
7+
from durabletask import worker as task_worker
78

89

910
class TestEntityTypes(unittest.TestCase):
@@ -108,5 +109,81 @@ def test_entity_query_result_creation(self):
108109
self.assertEqual(result.continuation_token, "next-page-token")
109110

110111

112+
class TestEntityWorkerIntegration(unittest.TestCase):
113+
114+
def test_worker_entity_registration(self):
115+
"""Test that entities can be registered with the worker."""
116+
worker = task_worker.TaskHubGrpcWorker()
117+
118+
def counter_entity(ctx: task.EntityContext, input):
119+
if ctx.operation_name == "increment":
120+
current_count = ctx.get_state() or 0
121+
new_count = current_count + (input or 1)
122+
ctx.set_state(new_count)
123+
return new_count
124+
elif ctx.operation_name == "get":
125+
return ctx.get_state() or 0
126+
elif ctx.operation_name == "reset":
127+
ctx.set_state(0)
128+
return 0
129+
130+
# Test registration
131+
entity_name = worker.add_entity(counter_entity)
132+
self.assertEqual(entity_name, "counter_entity")
133+
134+
# Test that entity is in registry
135+
self.assertIsNotNone(worker._registry.get_entity("counter_entity"))
136+
137+
# Test error for duplicate registration
138+
with self.assertRaises(ValueError):
139+
worker.add_entity(counter_entity)
140+
141+
def test_entity_execution(self):
142+
"""Test entity execution via the EntityExecutor."""
143+
from durabletask.worker import _Registry, _EntityExecutor
144+
import durabletask.internal.orchestrator_service_pb2 as pb
145+
import durabletask.internal.helpers as ph
146+
import logging
147+
148+
# Create registry and register entity
149+
registry = _Registry()
150+
151+
def counter_entity(ctx: task.EntityContext, input):
152+
if ctx.operation_name == "increment":
153+
current_count = ctx.get_state() or 0
154+
new_count = current_count + (input or 1)
155+
ctx.set_state(new_count)
156+
return new_count
157+
elif ctx.operation_name == "get":
158+
return ctx.get_state() or 0
159+
160+
# Register the entity with a specific name
161+
registry.add_named_entity("Counter", counter_entity)
162+
163+
# Create executor
164+
logger = logging.getLogger("test")
165+
executor = _EntityExecutor(registry, logger)
166+
167+
# Create test request
168+
req = pb.EntityBatchRequest()
169+
req.instanceId = "Counter@test-key" # Instance ID with entity type prefix matching registration
170+
req.entityState.CopyFrom(ph.get_string_value("0")) # Initial state
171+
172+
# Add increment operation
173+
operation = pb.OperationRequest()
174+
operation.operation = "increment"
175+
operation.input.CopyFrom(ph.get_string_value("5"))
176+
req.operations.append(operation)
177+
178+
# Execute
179+
result = executor.execute(req)
180+
181+
# Verify result
182+
self.assertEqual(len(result.results), 1)
183+
self.assertTrue(result.results[0].HasField("success"))
184+
self.assertEqual(result.results[0].success.result.value, "5")
185+
self.assertEqual(result.entityState.value, "5")
186+
187+
111188
if __name__ == '__main__':
112189
unittest.main()

0 commit comments

Comments
 (0)