@@ -75,10 +75,12 @@ def __init__(
7575class _Registry :
7676 orchestrators : dict [str , task .Orchestrator ]
7777 activities : dict [str , task .Activity ]
78+ entities : dict [str , task .Entity ]
7879
7980 def __init__ (self ):
8081 self .orchestrators = {}
8182 self .activities = {}
83+ self .entities = {}
8284
8385 def add_orchestrator (self , fn : task .Orchestrator ) -> str :
8486 if fn is None :
@@ -118,6 +120,25 @@ def add_named_activity(self, name: str, fn: task.Activity) -> None:
118120 def get_activity (self , name : str ) -> Optional [task .Activity ]:
119121 return self .activities .get (name )
120122
123+ def add_entity (self , fn : task .Entity ) -> str :
124+ if fn is None :
125+ raise ValueError ("An entity function argument is required." )
126+
127+ name = task .get_name (fn )
128+ self .add_named_entity (name , fn )
129+ return name
130+
131+ def add_named_entity (self , name : str , fn : task .Entity ) -> None :
132+ if not name :
133+ raise ValueError ("A non-empty entity name is required." )
134+ if name in self .entities :
135+ raise ValueError (f"A '{ name } ' entity already exists." )
136+
137+ self .entities [name ] = fn
138+
139+ def get_entity (self , name : str ) -> Optional [task .Entity ]:
140+ return self .entities .get (name )
141+
121142
122143class OrchestratorNotRegisteredError (ValueError ):
123144 """Raised when attempting to start an orchestration that is not registered"""
@@ -131,6 +152,12 @@ class ActivityNotRegisteredError(ValueError):
131152 pass
132153
133154
155+ class EntityNotRegisteredError (ValueError ):
156+ """Raised when attempting to call an entity that is not registered"""
157+
158+ pass
159+
160+
134161class TaskHubGrpcWorker :
135162 """A gRPC-based worker for processing durable task orchestrations and activities.
136163
@@ -279,6 +306,14 @@ def add_activity(self, fn: task.Activity) -> str:
279306 )
280307 return self ._registry .add_activity (fn )
281308
309+ def add_entity (self , fn : task .Entity ) -> str :
310+ """Registers an entity function with the worker."""
311+ if self ._is_running :
312+ raise RuntimeError (
313+ "Entities cannot be added while the worker is running."
314+ )
315+ return self ._registry .add_entity (fn )
316+
282317 def start (self ):
283318 """Starts the worker on a background thread and begins listening for work items."""
284319 if self ._is_running :
@@ -434,6 +469,13 @@ def stream_reader():
434469 stub ,
435470 work_item .completionToken ,
436471 )
472+ elif work_item .HasField ("entityRequest" ):
473+ self ._async_worker_manager .submit_activity (
474+ self ._execute_entity ,
475+ work_item .entityRequest ,
476+ stub ,
477+ work_item .completionToken ,
478+ )
437479 elif work_item .HasField ("healthPing" ):
438480 pass
439481 else :
@@ -569,6 +611,34 @@ def _execute_activity(
569611 f"Failed to deliver activity response for '{ req .name } #{ req .taskId } ' of orchestration ID '{ instance_id } ' to sidecar: { ex } "
570612 )
571613
614+ def _execute_entity (
615+ self ,
616+ req : pb .EntityBatchRequest ,
617+ stub : stubs .TaskHubSidecarServiceStub ,
618+ completionToken ,
619+ ):
620+ instance_id = req .instanceId
621+ try :
622+ executor = _EntityExecutor (self ._registry , self ._logger )
623+ result = executor .execute (req )
624+ result .completionToken = completionToken
625+ except Exception as ex :
626+ self ._logger .exception (
627+ f"An error occurred while trying to execute entity '{ instance_id } ': { ex } "
628+ )
629+ failure_details = ph .new_failure_details (ex )
630+ result = pb .EntityBatchResult (
631+ failureDetails = failure_details ,
632+ completionToken = completionToken ,
633+ )
634+
635+ try :
636+ stub .CompleteEntityTask (result )
637+ except Exception as ex :
638+ self ._logger .exception (
639+ f"Failed to deliver entity response for entity '{ instance_id } ' to sidecar: { ex } "
640+ )
641+
572642
573643class _RuntimeOrchestrationContext (task .OrchestrationContext ):
574644 _generator : Optional [Generator [task .Task , Any , Any ]]
@@ -858,6 +928,36 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
858928
859929 self .set_continued_as_new (new_input , save_events )
860930
931+ def signal_entity (self , entity_id : str , operation_name : str , * ,
932+ input : Optional [Any ] = None ) -> task .Task :
933+ # Create a signal entity action
934+ action = pb .OrchestratorAction ()
935+ action .sendEntitySignal .CopyFrom (pb .SendSignalAction (
936+ instanceId = entity_id ,
937+ name = operation_name ,
938+ input = ph .get_string_value (shared .to_json (input )) if input is not None else None
939+ ))
940+
941+ # Entity signals don't return values, so we create a completed task
942+ signal_task = task .CompletableTask ()
943+
944+ # Store the action to be executed
945+ task_id = self ._next_task_id ()
946+ self ._pending_actions [task_id ] = action
947+ self ._pending_tasks [task_id ] = signal_task
948+
949+ # Mark as complete since signals don't have return values
950+ signal_task .complete (None )
951+
952+ return signal_task
953+
954+ def call_entity (self , entity_id : str , operation_name : str , * ,
955+ input : Optional [Any ] = None ,
956+ retry_policy : Optional [task .RetryPolicy ] = None ) -> task .Task :
957+ # For now, entity calls are not directly supported in orchestrations
958+ # This would require additional protobuf support
959+ raise NotImplementedError ("Direct entity calls from orchestrations are not yet supported. Use signal_entity instead." )
960+
861961
862962class ExecutionResults :
863963 actions : list [pb .OrchestratorAction ]
@@ -1260,6 +1360,81 @@ def execute(
12601360 return encoded_output
12611361
12621362
1363+ class _EntityExecutor :
1364+ def __init__ (self , registry : _Registry , logger : logging .Logger ):
1365+ self ._registry = registry
1366+ self ._logger = logger
1367+
1368+ def execute (self , req : pb .EntityBatchRequest ) -> pb .EntityBatchResult :
1369+ """Executes entity operations and returns the batch result."""
1370+ instance_id = req .instanceId
1371+ self ._logger .debug (f"Executing entity batch for '{ instance_id } ' with { len (req .operations )} operation(s)..." )
1372+
1373+ # Parse current entity state
1374+ current_state = shared .from_json (req .entityState .value ) if not ph .is_empty (req .entityState ) else None
1375+
1376+ # Extract entity type from instance ID (format: entitytype@key)
1377+ entity_type = "Unknown"
1378+ if "@" in instance_id :
1379+ entity_type = instance_id .split ("@" )[0 ]
1380+
1381+ results = []
1382+ actions = []
1383+
1384+ for operation in req .operations :
1385+ try :
1386+ # Get the entity function using the entity type from instanceId
1387+ fn = self ._registry .get_entity (entity_type )
1388+ if not fn :
1389+ raise EntityNotRegisteredError (f"Entity function named '{ entity_type } ' was not registered!" )
1390+
1391+ # Create entity context
1392+ ctx = task .EntityContext (
1393+ instance_id = instance_id ,
1394+ operation_name = operation .operation ,
1395+ is_new_entity = (current_state is None )
1396+ )
1397+ ctx .set_state (current_state )
1398+
1399+ # Parse operation input
1400+ operation_input = shared .from_json (operation .input .value ) if not ph .is_empty (operation .input ) else None
1401+
1402+ # Execute the entity operation
1403+ operation_output = fn (ctx , operation_input )
1404+
1405+ # Update state for next operation
1406+ current_state = ctx .get_state ()
1407+
1408+ # Create operation result
1409+ result = pb .OperationResult ()
1410+ if operation_output is not None :
1411+ result .success .CopyFrom (pb .OperationResultSuccess (
1412+ result = ph .get_string_value (shared .to_json (operation_output ))
1413+ ))
1414+ else :
1415+ result .success .CopyFrom (pb .OperationResultSuccess ())
1416+
1417+ results .append (result )
1418+
1419+ except Exception as ex :
1420+ self ._logger .exception (f"Error executing entity operation '{ operation .operation } ' on entity type '{ entity_type } ': { ex } " )
1421+
1422+ # Create failure result
1423+ failure_details = ph .new_failure_details (ex )
1424+ result = pb .OperationResult ()
1425+ result .failure .CopyFrom (pb .OperationResultFailure (
1426+ failureDetails = failure_details
1427+ ))
1428+ results .append (result )
1429+
1430+ # Return batch result
1431+ return pb .EntityBatchResult (
1432+ results = results ,
1433+ actions = actions ,
1434+ entityState = ph .get_string_value (shared .to_json (current_state )) if current_state is not None else None
1435+ )
1436+
1437+
12631438def _get_non_determinism_error (
12641439 task_id : int , action_name : str
12651440) -> task .NonDeterminismError :
0 commit comments