@@ -150,23 +150,23 @@ def __init__(self):
150150 self .entities = {}
151151 self .entity_instances = {}
152152
153- def add_orchestrator (self , fn : task .Orchestrator ) -> str :
153+ def add_orchestrator (self , fn : task .Orchestrator [ TInput , TOutput ] ) -> str :
154154 if fn is None :
155155 raise ValueError ("An orchestrator function argument is required." )
156156
157157 name = task .get_name (fn )
158158 self .add_named_orchestrator (name , fn )
159159 return name
160160
161- def add_named_orchestrator (self , name : str , fn : task .Orchestrator ) -> None :
161+ def add_named_orchestrator (self , name : str , fn : task .Orchestrator [ TInput , TOutput ] ) -> None :
162162 if not name :
163163 raise ValueError ("A non-empty orchestrator name is required." )
164164 if name in self .orchestrators :
165165 raise ValueError (f"A '{ name } ' orchestrator already exists." )
166166
167167 self .orchestrators [name ] = fn
168168
169- def get_orchestrator (self , name : str ) -> Optional [task .Orchestrator ]:
169+ def get_orchestrator (self , name : str ) -> Optional [task .Orchestrator [ Any , Any ] ]:
170170 return self .orchestrators .get (name )
171171
172172 def add_activity (self , fn : task .Activity ) -> str :
@@ -188,16 +188,13 @@ def add_named_activity(self, name: str, fn: task.Activity) -> None:
188188 def get_activity (self , name : str ) -> Optional [task .Activity ]:
189189 return self .activities .get (name )
190190
191- def add_entity (self , fn : task .Entity ) -> str :
191+ def add_entity (self , fn : task .Entity , name : Optional [ str ] = None ) -> str :
192192 if fn is None :
193193 raise ValueError ("An entity function argument is required." )
194194
195- if isinstance (fn , type ) and issubclass (fn , DurableEntity ):
196- name = fn .__name__
197- self .add_named_entity (name , fn )
198- else :
199- name = task .get_name (fn )
200- self .add_named_entity (name , fn )
195+ if name is None :
196+ name = task .get_entity_name (fn )
197+ self .add_named_entity (name , fn )
201198 return name
202199
203200 def add_named_entity (self , name : str , fn : task .Entity ) -> None :
@@ -207,6 +204,7 @@ def add_named_entity(self, name: str, fn: task.Entity) -> None:
207204 raise ValueError (f"A '{ name } ' entity already exists." )
208205
209206 self .entities [name ] = fn
207+ setattr (fn , "__durable_entity_name__" , name )
210208
211209 def get_entity (self , name : str ) -> Optional [task .Entity ]:
212210 return self .entities .get (name )
@@ -362,7 +360,7 @@ def __enter__(self):
362360 def __exit__ (self , type , value , traceback ):
363361 self .stop ()
364362
365- def add_orchestrator (self , fn : task .Orchestrator ) -> str :
363+ def add_orchestrator (self , fn : task .Orchestrator [ TInput , TOutput ] ) -> str :
366364 """Registers an orchestrator function with the worker."""
367365 if self ._is_running :
368366 raise RuntimeError (
@@ -378,13 +376,13 @@ def add_activity(self, fn: task.Activity) -> str:
378376 )
379377 return self ._registry .add_activity (fn )
380378
381- def add_entity (self , fn : task .Entity ) -> str :
379+ def add_entity (self , fn : task .Entity , name : Optional [ str ] = None ) -> str :
382380 """Registers an entity function with the worker."""
383381 if self ._is_running :
384382 raise RuntimeError (
385383 "Entities cannot be added while the worker is running."
386384 )
387- return self ._registry .add_entity (fn )
385+ return self ._registry .add_entity (fn , name )
388386
389387 def use_versioning (self , version : VersioningOptions ) -> None :
390388 """Initializes versioning options for sub-orchestrators and activities."""
@@ -1044,21 +1042,21 @@ def call_activity(
10441042
10451043 def call_entity (
10461044 self ,
1047- entity : EntityInstanceId ,
1045+ entity : EntityInstanceId [ TInput , TOutput ] ,
10481046 operation : str ,
10491047 input : Optional [TInput ] = None ,
1050- ) -> task .CompletableTask :
1048+ ) -> task .CompletableTask [ TOutput ] :
10511049 id = self .next_sequence_number ()
10521050
10531051 self .call_entity_function_helper (
10541052 id , entity , operation , input = input
10551053 )
10561054
1057- return self ._pending_tasks .get (id , task .CompletableTask ())
1055+ return self ._pending_tasks .get (id , task .CompletableTask [ TOutput ] ())
10581056
10591057 def signal_entity (
10601058 self ,
1061- entity_id : EntityInstanceId ,
1059+ entity_id : EntityInstanceId [ TInput , TOutput ] ,
10621060 operation_name : str ,
10631061 input : Optional [TInput ] = None
10641062 ) -> None :
@@ -1068,7 +1066,7 @@ def signal_entity(
10681066 id , entity_id , operation_name , input
10691067 )
10701068
1071- def lock_entities (self , entities : list [EntityInstanceId ]) -> task .CompletableTask [EntityLock ]:
1069+ def lock_entities (self , entities : list [EntityInstanceId [ Any , Any ] ]) -> task .CompletableTask [EntityLock ]:
10721070 id = self .next_sequence_number ()
10731071
10741072 self .lock_entities_function_helper (
@@ -1158,11 +1156,11 @@ def call_activity_function_helper(
11581156 def call_entity_function_helper (
11591157 self ,
11601158 id : Optional [int ],
1161- entity_id : EntityInstanceId ,
1159+ entity_id : EntityInstanceId [ TInput , TOutput ] ,
11621160 operation : str ,
11631161 * ,
11641162 input : Optional [TInput ] = None ,
1165- ):
1163+ ) -> None :
11661164 if id is None :
11671165 id = self .next_sequence_number ()
11681166
@@ -1180,7 +1178,7 @@ def call_entity_function_helper(
11801178 def signal_entity_function_helper (
11811179 self ,
11821180 id : Optional [int ],
1183- entity_id : EntityInstanceId ,
1181+ entity_id : EntityInstanceId [ TInput , TOutput ] ,
11841182 operation : str ,
11851183 input : Optional [TInput ]
11861184 ) -> None :
@@ -1197,7 +1195,7 @@ def signal_entity_function_helper(
11971195 action = ph .new_signal_entity_action (id , entity_id , operation , encoded_input , self .new_uuid ())
11981196 self ._pending_actions [id ] = action
11991197
1200- def lock_entities_function_helper (self , id : int , entities : list [EntityInstanceId ]) -> None :
1198+ def lock_entities_function_helper (self , id : int , entities : list [EntityInstanceId [ Any , Any ] ]) -> None :
12011199 if id is None :
12021200 id = self .next_sequence_number ()
12031201
@@ -1792,7 +1790,7 @@ def process_event(
17921790 # The orchestrator generator function completed
17931791 ctx .set_complete (generatorStopped .value , pb .ORCHESTRATION_STATUS_COMPLETED )
17941792
1795- def _parse_entity_event_sent_input (self , event : pb .HistoryEvent ) -> Tuple [EntityInstanceId , str ]:
1793+ def _parse_entity_event_sent_input (self , event : pb .HistoryEvent ) -> Tuple [EntityInstanceId [ Any , Any ] , str ]:
17961794 try :
17971795 entity_id = EntityInstanceId .parse (event .eventSent .instanceId )
17981796 except ValueError :
@@ -1806,7 +1804,7 @@ def _parse_entity_event_sent_input(self, event: pb.HistoryEvent) -> Tuple[Entity
18061804 def _handle_entity_event_raised (self ,
18071805 ctx : _RuntimeOrchestrationContext ,
18081806 event : pb .HistoryEvent ,
1809- entity_id : Optional [EntityInstanceId ],
1807+ entity_id : Optional [EntityInstanceId [ Any , Any ] ],
18101808 task_id : Optional [int ],
18111809 is_lock_event : bool ):
18121810 # This eventRaised represents the result of an entity operation after being translated to the old
@@ -1919,7 +1917,7 @@ def __init__(self, registry: _Registry, logger: logging.Logger):
19191917 def execute (
19201918 self ,
19211919 orchestration_id : str ,
1922- entity_id : EntityInstanceId ,
1920+ entity_id : EntityInstanceId [ TInput , TOutput ] ,
19231921 operation : str ,
19241922 state : StateShim ,
19251923 encoded_input : Optional [str ],
0 commit comments