Skip to content

Commit 7c82f24

Browse files
authored
Improve TimerTask (#122)
* Add cancellable tasks * Implement long timer support * Add DTS default, e2e test * PR Feedback * Lint * Fix merge issues * PR feedback - merge TimerTask/LongTimerTask
1 parent 33dc25a commit 7c82f24

File tree

8 files changed

+640
-43
lines changed

8 files changed

+640
-43
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[flake8]
2-
ignore = E501,C901
2+
ignore = E501,C901,W503
33
exclude =
44
.git
55
*_pb2*

docs/supported-patterns.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order):
6464
# Orders of $1000 or more require manager approval
6565
yield ctx.call_activity(send_approval_request, input=order)
6666

67-
# Approvals must be received within 24 hours or they will be canceled.
67+
# Approvals must be received within 24 hours or they will be cancelled.
6868
approval_event = ctx.wait_for_external_event("approval_received")
6969
timeout_event = ctx.create_timer(timedelta(hours=24))
7070
winner = yield task.when_any([approval_event, timeout_event])
7171
if winner == timeout_event:
72-
return "Canceled"
72+
return "Cancelled"
7373

7474
# The order was approved
7575
yield ctx.call_activity(place_order, input=order)

durabletask-azuremanaged/durabletask/azuremanaged/worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,7 @@ def __init__(self, *,
8585
log_formatter=log_formatter,
8686
interceptors=interceptors,
8787
concurrency_options=concurrency_options,
88-
payload_store=payload_store)
88+
# DTS natively supports long timers so chunking is unnecessary
89+
maximum_timer_interval=None,
90+
payload_store=payload_store
91+
)

durabletask/task.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def set_custom_status(self, custom_status: Any) -> None:
9898
pass
9999

100100
@abstractmethod
101-
def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
101+
def create_timer(self, fire_at: Union[datetime, timedelta]) -> CancellableTask:
102102
"""Create a Timer Task to fire after at the specified deadline.
103103
104104
Parameters
@@ -228,10 +228,10 @@ def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput
228228
"""
229229
pass
230230

231-
# TOOD: Add a timeout parameter, which allows the task to be canceled if the event is
231+
# TOOD: Add a timeout parameter, which allows the task to be cancelled if the event is
232232
# not received within the specified timeout. This requires support for task cancellation.
233233
@abstractmethod
234-
def wait_for_external_event(self, name: str) -> CompletableTask:
234+
def wait_for_external_event(self, name: str) -> CancellableTask:
235235
"""Wait asynchronously for an event to be raised with the name `name`.
236236
237237
Parameters
@@ -324,6 +324,10 @@ class OrchestrationStateError(Exception):
324324
pass
325325

326326

327+
class TaskCancelledError(Exception):
328+
"""Exception type for cancelled orchestration tasks."""
329+
330+
327331
class Task(ABC, Generic[T]):
328332
"""Abstract base class for asynchronous tasks in a durable orchestration."""
329333
_result: T
@@ -435,6 +439,48 @@ def fail(self, message: str, details: Union[Exception, pb.TaskFailureDetails]):
435439
self._parent.on_child_completed(self)
436440

437441

442+
class CancellableTask(CompletableTask[T]):
443+
"""A completable task that can be cancelled before it finishes."""
444+
445+
def __init__(self) -> None:
446+
super().__init__()
447+
self._is_cancelled = False
448+
self._cancel_handler: Optional[Callable[[], None]] = None
449+
450+
@property
451+
def is_cancelled(self) -> bool:
452+
"""Returns True if the task was cancelled, False otherwise."""
453+
return self._is_cancelled
454+
455+
def get_result(self) -> T:
456+
if self._is_cancelled:
457+
raise TaskCancelledError('The task was cancelled.')
458+
return super().get_result()
459+
460+
def set_cancel_handler(self, cancel_handler: Callable[[], None]) -> None:
461+
self._cancel_handler = cancel_handler
462+
463+
def cancel(self) -> bool:
464+
"""Attempts to cancel this task.
465+
466+
Returns
467+
-------
468+
bool
469+
True if cancellation was applied, False if the task had already completed.
470+
"""
471+
if self._is_complete:
472+
return False
473+
474+
if self._cancel_handler is not None:
475+
self._cancel_handler()
476+
477+
self._is_cancelled = True
478+
self._is_complete = True
479+
if self._parent is not None:
480+
self._parent.on_child_completed(self)
481+
return True
482+
483+
438484
class RetryableTask(CompletableTask[T]):
439485
"""A task that can be retried according to a retry policy."""
440486

@@ -474,14 +520,29 @@ def compute_next_delay(self) -> Optional[timedelta]:
474520
return None
475521

476522

477-
class TimerTask(CompletableTask[T]):
478-
479-
def __init__(self) -> None:
523+
class TimerTask(CancellableTask[None]):
524+
def __init__(self, final_fire_at: Optional[datetime] = None,
525+
maximum_timer_interval: Optional[timedelta] = None):
480526
super().__init__()
527+
self._final_fire_at = final_fire_at
528+
self._maximum_timer_interval = maximum_timer_interval
481529

482530
def set_retryable_parent(self, retryable_task: RetryableTask):
483531
self._retryable_parent = retryable_task
484532

533+
def _handle_timer_fired(self, current_utc_datetime: datetime) -> Optional[datetime]:
534+
if (self._final_fire_at is not None
535+
and self._maximum_timer_interval is not None
536+
and current_utc_datetime < self._final_fire_at):
537+
return self._get_next_fire_at(current_utc_datetime)
538+
super().complete(None)
539+
return None
540+
541+
def _get_next_fire_at(self, current_utc_datetime: datetime) -> datetime:
542+
if current_utc_datetime + self._maximum_timer_interval < self._final_fire_at:
543+
return current_utc_datetime + self._maximum_timer_interval
544+
return self._final_fire_at
545+
485546

486547
class WhenAnyTask(CompositeTask[Task]):
487548
"""A task that completes when any of its child tasks complete."""

0 commit comments

Comments
 (0)