Skip to content

Commit 9b461c3

Browse files
committed
Improve entity typing
1 parent 1939eea commit 9b461c3

File tree

7 files changed

+106
-46
lines changed

7 files changed

+106
-46
lines changed

durabletask/client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
import grpc
1212
from google.protobuf import wrappers_pb2
1313

14-
from durabletask.entities import EntityInstanceId
15-
from durabletask.entities.entity_metadata import EntityMetadata
14+
from durabletask.entities import EntityInstanceId, EntityMetadata
1615
import durabletask.internal.helpers as helpers
1716
import durabletask.internal.orchestrator_service_pb2 as pb
1817
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
@@ -230,7 +229,10 @@ def purge_orchestration(self, instance_id: str, recursive: bool = True):
230229
self._logger.info(f"Purging instance '{instance_id}'.")
231230
self._stub.PurgeInstances(req)
232231

233-
def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Optional[Any] = None):
232+
def signal_entity(self,
233+
entity_instance_id: EntityInstanceId[TInput, TOutput],
234+
operation_name: str,
235+
input: Optional[TInput] = None):
234236
req = pb.SignalEntityRequest(
235237
instanceId=str(entity_instance_id),
236238
name=operation_name,
@@ -244,7 +246,7 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: st
244246
self._stub.SignalEntity(req, None) # TODO: Cancellation timeout?
245247

246248
def get_entity(self,
247-
entity_instance_id: EntityInstanceId,
249+
entity_instance_id: EntityInstanceId[Any, Any],
248250
include_state: bool = True
249251
) -> Optional[EntityMetadata]:
250252
req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)

durabletask/entities/durable_entity.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from durabletask.entities.entity_instance_id import EntityInstanceId
55

66
TState = TypeVar("TState")
7+
TInput = TypeVar("TInput")
78

89

910
class DurableEntity:
@@ -49,7 +50,10 @@ def set_state(self, state: Any):
4950
"""
5051
self.entity_context.set_state(state)
5152

52-
def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None:
53+
def signal_entity(self,
54+
entity_instance_id: EntityInstanceId[TInput, Any],
55+
operation: str,
56+
input: Optional[TInput] = None) -> None:
5357
"""Signal another entity to perform an operation.
5458
5559
Parameters

durabletask/entities/entity_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import durabletask.internal.orchestrator_service_pb2 as pb
88

99
TState = TypeVar("TState")
10+
TInput = TypeVar("TInput")
1011

1112

1213
class EntityContext:
@@ -81,7 +82,7 @@ def set_state(self, new_state: Any):
8182
"""
8283
self._state.set_state(new_state)
8384

84-
def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Optional[Any] = None) -> None:
85+
def signal_entity(self, entity_instance_id: EntityInstanceId[TInput, Any], operation: str, input: Optional[TInput] = None) -> None:
8586
"""Signal another entity to perform an operation.
8687
8788
Parameters

durabletask/entities/entity_instance_id.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,54 @@
1-
class EntityInstanceId:
2-
def __init__(self, entity: str, key: str):
3-
self.entity = entity
4-
self.key = key
1+
from typing import Any, Callable, TypeVar, Union, overload, TYPE_CHECKING
2+
3+
if TYPE_CHECKING:
4+
from durabletask import task
5+
from durabletask.entities.durable_entity import DurableEntity
6+
from durabletask.entities.entity_context import EntityContext
7+
8+
9+
TInput = TypeVar('TInput')
10+
TOutput = TypeVar('TOutput')
11+
12+
13+
class EntityInstanceId[TInput, TOutput]:
14+
@overload
15+
def __new__(
16+
cls,
17+
entity: Callable[[EntityContext, TInput], TOutput],
18+
key: str
19+
) -> "EntityInstanceId[TInput, TOutput]": ...
20+
21+
@overload
22+
def __new__(
23+
cls,
24+
entity: type[DurableEntity],
25+
key: str
26+
) -> "EntityInstanceId[Any, Any]": ...
27+
28+
@overload
29+
def __new__(
30+
cls,
31+
entity: str,
32+
key: str
33+
) -> "EntityInstanceId[Any, Any]": ...
34+
35+
def __new__(
36+
cls,
37+
entity: Union[task.Entity[TInput, TOutput], str],
38+
key: str
39+
) -> "EntityInstanceId[Any, Any]":
40+
return super().__new__(cls)
41+
42+
def __init__(
43+
self,
44+
entity: Union[task.Entity[TInput, TOutput], str],
45+
key: str
46+
):
47+
if not isinstance(entity, str):
48+
from durabletask import task
49+
entity = task.get_entity_name(entity)
50+
self.entity: str = entity
51+
self.key: str = key
552

653
def __str__(self) -> str:
754
return f"@{self.entity}@{self.key}"
@@ -17,7 +64,7 @@ def __lt__(self, other):
1764
return str(self) < str(other)
1865

1966
@staticmethod
20-
def parse(entity_id: str) -> "EntityInstanceId":
67+
def parse(entity_id: str) -> "EntityInstanceId[Any, Any]":
2168
"""Parse a string representation of an entity ID into an EntityInstanceId object.
2269
2370
Parameters

durabletask/entities/entity_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class EntityMetadata:
2424
"""
2525

2626
def __init__(self,
27-
id: EntityInstanceId,
27+
id: EntityInstanceId[Any, Any],
2828
last_modified: datetime,
2929
backlog_queue_size: int,
3030
locked_by: str,

durabletask/task.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
140140

141141
@abstractmethod
142142
def call_entity(self,
143-
entity: EntityInstanceId,
143+
entity: EntityInstanceId[TInput, TOutput],
144144
operation: str,
145-
input: Optional[TInput] = None) -> CompletableTask:
145+
input: Optional[TInput] = None) -> CompletableTask[TOutput]:
146146
"""Schedule entity function for execution.
147147
148148
Parameters
149149
----------
150-
entity: EntityInstanceId
150+
entity: EntityInstanceId[TInput, TOutput]
151151
The ID of the entity instance to call.
152152
operation: str
153153
The name of the operation to invoke on the entity.
@@ -164,15 +164,15 @@ def call_entity(self,
164164
@abstractmethod
165165
def signal_entity(
166166
self,
167-
entity_id: EntityInstanceId,
167+
entity_id: EntityInstanceId[TInput, TOutput],
168168
operation_name: str,
169169
input: Optional[TInput] = None
170170
) -> None:
171171
"""Signal an entity function for execution.
172172
173173
Parameters
174174
----------
175-
entity_id: EntityInstanceId
175+
entity_id: EntityInstanceId[TInput, TOutput]
176176
The ID of the entity instance to signal.
177177
operation_name: str
178178
The name of the operation to invoke on the entity.
@@ -182,7 +182,7 @@ def signal_entity(
182182
pass
183183

184184
@abstractmethod
185-
def lock_entities(self, entities: list[EntityInstanceId]) -> CompletableTask[EntityLock]:
185+
def lock_entities(self, entities: list[EntityInstanceId[Any, Any]]) -> CompletableTask[EntityLock]:
186186
"""Creates a Task object that locks the specified entity instances.
187187
188188
The locks will be acquired the next time the orchestrator yields.
@@ -191,7 +191,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> CompletableTask[Ent
191191
192192
Parameters
193193
----------
194-
entities: list[EntityInstanceId]
194+
entities: list[EntityInstanceId[Any, Any]]
195195
The list of entity instance IDs to lock.
196196
197197
Returns
@@ -538,8 +538,8 @@ def task_id(self) -> int:
538538
return self._task_id
539539

540540

541-
# Orchestrators are generators that yield tasks and receive/return any type
542-
Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, Any, Any], TOutput]]
541+
# Orchestrators are generators that yield tasks, recieve any type, and return TOutput
542+
Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task[Any], Any, TOutput], TOutput]]
543543

544544
# Activities are simple functions that can be scheduled by orchestrators
545545
Activity = Callable[[ActivityContext, TInput], TOutput]
@@ -615,6 +615,14 @@ def retry_timeout(self) -> Optional[timedelta]:
615615
return self._retry_timeout
616616

617617

618+
def get_entity_name(fn: Entity) -> str:
619+
if hasattr(fn, "__durable_entity_name__"):
620+
return getattr(fn, "__durable_entity_name__")
621+
if isinstance(fn, type) and issubclass(fn, DurableEntity):
622+
return fn.__name__
623+
return get_name(fn)
624+
625+
618626
def get_name(fn: Callable) -> str:
619627
"""Returns the name of the provided function"""
620628
name = fn.__name__

durabletask/worker.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -150,23 +150,23 @@ def __init__(self):
150150
self.entities = {}
151151
self.entity_instances = {}
152152

153-
def add_orchestrator(self, fn: task.Orchestrator) -> str:
153+
def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
154154
if fn is None:
155155
raise ValueError("An orchestrator function argument is required.")
156156

157157
name = task.get_name(fn)
158158
self.add_named_orchestrator(name, fn)
159159
return name
160160

161-
def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None:
161+
def add_named_orchestrator(self, name: str, fn: task.Orchestrator[TInput, TOutput]) -> None:
162162
if not name:
163163
raise ValueError("A non-empty orchestrator name is required.")
164164
if name in self.orchestrators:
165165
raise ValueError(f"A '{name}' orchestrator already exists.")
166166

167167
self.orchestrators[name] = fn
168168

169-
def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]:
169+
def get_orchestrator(self, name: str) -> Optional[task.Orchestrator[Any, Any]]:
170170
return self.orchestrators.get(name)
171171

172172
def add_activity(self, fn: task.Activity) -> str:
@@ -188,16 +188,13 @@ def add_named_activity(self, name: str, fn: task.Activity) -> None:
188188
def get_activity(self, name: str) -> Optional[task.Activity]:
189189
return self.activities.get(name)
190190

191-
def add_entity(self, fn: task.Entity) -> str:
191+
def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
192192
if fn is None:
193193
raise ValueError("An entity function argument is required.")
194194

195-
if isinstance(fn, type) and issubclass(fn, DurableEntity):
196-
name = fn.__name__
197-
self.add_named_entity(name, fn)
198-
else:
199-
name = task.get_name(fn)
200-
self.add_named_entity(name, fn)
195+
if name is None:
196+
name = task.get_entity_name(fn)
197+
self.add_named_entity(name, fn)
201198
return name
202199

203200
def add_named_entity(self, name: str, fn: task.Entity) -> None:
@@ -207,6 +204,7 @@ def add_named_entity(self, name: str, fn: task.Entity) -> None:
207204
raise ValueError(f"A '{name}' entity already exists.")
208205

209206
self.entities[name] = fn
207+
setattr(fn, "__durable_entity_name__", name)
210208

211209
def get_entity(self, name: str) -> Optional[task.Entity]:
212210
return self.entities.get(name)
@@ -362,7 +360,7 @@ def __enter__(self):
362360
def __exit__(self, type, value, traceback):
363361
self.stop()
364362

365-
def add_orchestrator(self, fn: task.Orchestrator) -> str:
363+
def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
366364
"""Registers an orchestrator function with the worker."""
367365
if self._is_running:
368366
raise RuntimeError(
@@ -378,13 +376,13 @@ def add_activity(self, fn: task.Activity) -> str:
378376
)
379377
return self._registry.add_activity(fn)
380378

381-
def add_entity(self, fn: task.Entity) -> str:
379+
def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
382380
"""Registers an entity function with the worker."""
383381
if self._is_running:
384382
raise RuntimeError(
385383
"Entities cannot be added while the worker is running."
386384
)
387-
return self._registry.add_entity(fn)
385+
return self._registry.add_entity(fn, name)
388386

389387
def use_versioning(self, version: VersioningOptions) -> None:
390388
"""Initializes versioning options for sub-orchestrators and activities."""
@@ -1044,21 +1042,21 @@ def call_activity(
10441042

10451043
def call_entity(
10461044
self,
1047-
entity: EntityInstanceId,
1045+
entity: EntityInstanceId[TInput, TOutput],
10481046
operation: str,
10491047
input: Optional[TInput] = None,
1050-
) -> task.CompletableTask:
1048+
) -> task.CompletableTask[TOutput]:
10511049
id = self.next_sequence_number()
10521050

10531051
self.call_entity_function_helper(
10541052
id, entity, operation, input=input
10551053
)
10561054

1057-
return self._pending_tasks.get(id, task.CompletableTask())
1055+
return self._pending_tasks.get(id, task.CompletableTask[TOutput]())
10581056

10591057
def signal_entity(
10601058
self,
1061-
entity_id: EntityInstanceId,
1059+
entity_id: EntityInstanceId[TInput, TOutput],
10621060
operation_name: str,
10631061
input: Optional[TInput] = None
10641062
) -> None:
@@ -1068,7 +1066,7 @@ def signal_entity(
10681066
id, entity_id, operation_name, input
10691067
)
10701068

1071-
def lock_entities(self, entities: list[EntityInstanceId]) -> task.CompletableTask[EntityLock]:
1069+
def lock_entities(self, entities: list[EntityInstanceId[Any, Any]]) -> task.CompletableTask[EntityLock]:
10721070
id = self.next_sequence_number()
10731071

10741072
self.lock_entities_function_helper(
@@ -1158,11 +1156,11 @@ def call_activity_function_helper(
11581156
def call_entity_function_helper(
11591157
self,
11601158
id: Optional[int],
1161-
entity_id: EntityInstanceId,
1159+
entity_id: EntityInstanceId[TInput, TOutput],
11621160
operation: str,
11631161
*,
11641162
input: Optional[TInput] = None,
1165-
):
1163+
) -> None:
11661164
if id is None:
11671165
id = self.next_sequence_number()
11681166

@@ -1180,7 +1178,7 @@ def call_entity_function_helper(
11801178
def signal_entity_function_helper(
11811179
self,
11821180
id: Optional[int],
1183-
entity_id: EntityInstanceId,
1181+
entity_id: EntityInstanceId[TInput, TOutput],
11841182
operation: str,
11851183
input: Optional[TInput]
11861184
) -> None:
@@ -1197,7 +1195,7 @@ def signal_entity_function_helper(
11971195
action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input, self.new_uuid())
11981196
self._pending_actions[id] = action
11991197

1200-
def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId]) -> None:
1198+
def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId[Any, Any]]) -> None:
12011199
if id is None:
12021200
id = self.next_sequence_number()
12031201

@@ -1792,7 +1790,7 @@ def process_event(
17921790
# The orchestrator generator function completed
17931791
ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)
17941792

1795-
def _parse_entity_event_sent_input(self, event: pb.HistoryEvent) -> Tuple[EntityInstanceId, str]:
1793+
def _parse_entity_event_sent_input(self, event: pb.HistoryEvent) -> Tuple[EntityInstanceId[Any, Any], str]:
17961794
try:
17971795
entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
17981796
except ValueError:
@@ -1806,7 +1804,7 @@ def _parse_entity_event_sent_input(self, event: pb.HistoryEvent) -> Tuple[Entity
18061804
def _handle_entity_event_raised(self,
18071805
ctx: _RuntimeOrchestrationContext,
18081806
event: pb.HistoryEvent,
1809-
entity_id: Optional[EntityInstanceId],
1807+
entity_id: Optional[EntityInstanceId[Any, Any]],
18101808
task_id: Optional[int],
18111809
is_lock_event: bool):
18121810
# This eventRaised represents the result of an entity operation after being translated to the old
@@ -1919,7 +1917,7 @@ def __init__(self, registry: _Registry, logger: logging.Logger):
19191917
def execute(
19201918
self,
19211919
orchestration_id: str,
1922-
entity_id: EntityInstanceId,
1920+
entity_id: EntityInstanceId[TInput, TOutput],
19231921
operation: str,
19241922
state: StateShim,
19251923
encoded_input: Optional[str],

0 commit comments

Comments
 (0)